## Task augmentation

In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torchvision.models import resnet50
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torch.nn.functional as F
from tqdm.auto import tqdm,trange
from pathlib import Path
import pandas as pd
import torch.nn.utils.prune as prune
import numpy as np
import random
import learn2learn as l2l
import pickle
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
ways = 5 
shots =5

In [4]:
with open("./data/test.pkl",'rb') as f:
    test_data = pickle.load(f)
with open("./data/train.pkl",'rb') as f:
    train_data = pickle.load(f)
with open("./data/validation.pkl",'rb') as f:
    val_data = pickle.load(f)

In [5]:
train_data['images'].shape

(38400, 3, 84, 84)

In [6]:
val_data['images'].shape

(9600, 3, 84, 84)

# Task augmentation

In [7]:
%%time
images = train_data["images"]
labels = train_data["labels"]
num_labels = np.unique(labels).shape[0]
num_new_labels = num_labels*3
n_data = int(len(labels)/num_labels)
for i in range(num_labels):
    images_of_one_class = images[i*n_data:(i+1)*n_data] #600
    new_img1 = np.array([np.rot90(img.copy(), axes=(2,1)) for img in images_of_one_class]) 
    new_img2 = np.array([np.rot90(img.copy(),2, axes=(2,1)) for img in images_of_one_class])
    new_img3 = np.array([np.rot90(img.copy(),3, axes=(2,1)) for img in images_of_one_class])
    train_data["images"] = np.concatenate((train_data["images"],new_img1,new_img2,new_img3))
    train_data["labels"] = np.concatenate((train_data["labels"],[num_labels+3*i]*n_data,[num_labels+3*i+1]*n_data,[num_labels+3*i+2]*n_data))
    
    # print(train_data['labels'].shape)
    # break

CPU times: user 31 s, sys: 38.4 s, total: 1min 9s
Wall time: 1min 9s


In [8]:
train_data["images"].shape,train_data["labels"].shape

((153600, 3, 84, 84), (153600,))

In [11]:
%%time
images = val_data["images"]
labels = val_data["labels"]
num_labels = np.unique(labels).shape[0]
num_new_labels = num_labels*3
n_data = int(len(labels)/num_labels)
for i in range(num_labels):
    images_of_one_class = images[i*n_data:(i+1)*n_data]
    new_img1 = np.array([np.rot90(img.copy(), axes=(2,1)) for img in images_of_one_class]) 
    new_img2 = np.array([np.rot90(img.copy(),2, axes=(2,1)) for img in images_of_one_class])
    new_img3 = np.array([np.rot90(img.copy(),3, axes=(2,1)) for img in images_of_one_class])
    val_data["images"] = np.concatenate((val_data["images"],new_img1,new_img2,new_img3))
    val_data["labels"] = np.concatenate((val_data["labels"],[num_labels+3*i]*n_data,[num_labels+3*i+1]*n_data,[num_labels+3*i+2]*n_data))

CPU times: user 2.33 s, sys: 3.12 s, total: 5.45 s
Wall time: 5.45 s


In [12]:
val_data["images"].shape,val_data["labels"].shape

((38400, 3, 84, 84), (38400,))

In [13]:
class Dataset(Dataset):
    def __init__(self,data,transform = None):
        self.transform = transform
        self.data=data
        self.images = np.transpose(self.data['images'],(0,2,3,1))
        self.labels = data['labels']
    def __len__(self): 
        return len(self.data['images'])
    def __getitem__(self,index):
        img = self.images[index]
        label = self.labels[index]
        if self.transform:
            img = self.transform(img)
        return img, label

In [14]:
train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees = 0, translate=(0.1,0.1)),
    ])
val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees = 0, translate=(0.1,0.1))
])

In [15]:
train_dataset = Dataset(train_data,train_transform)
val_dataset = Dataset(val_data, val_transform)

In [16]:
train_dataset = l2l.data.MetaDataset(train_dataset)
val_dataset = l2l.data.MetaDataset(val_dataset)

In [17]:
task_transforms = [
    l2l.data.transforms.FusedNWaysKShots(train_dataset, n=ways, k=2*shots),
    l2l.data.transforms.LoadData(train_dataset),
    l2l.data.transforms.RemapLabels(train_dataset),
    l2l.data.transforms.ConsecutiveLabels(train_dataset),
    # l2l.vision.transforms.RandomClassRotation(train_dataset,degrees = [0,10,15,20])
]
train_taskset =  l2l.data.TaskDataset(train_dataset, task_transforms, num_tasks=20000)
task_transforms = [
    l2l.data.transforms.FusedNWaysKShots(val_dataset, n = ways, k=2*shots),
    l2l.data.transforms.LoadData(val_dataset),
    l2l.data.transforms.RemapLabels(val_dataset),
    l2l.data.transforms.ConsecutiveLabels(val_dataset),
    # l2l.vision.transforms.RandomClassRotation(val_dataset,degrees = [0,10,15,20])
]
val_taskset =  l2l.data.TaskDataset(val_dataset, task_transforms, num_tasks=20000)

In [19]:
def adaptation(task, learner, fas,loss_func, device):
    data, labels = task
    data = data.to(device)
    labels = labels.to(device)
    sup_mask = np.array(([False]*5+[True]*5)*5)
    query_mask = ~sup_mask
    sup_data , query_data = data[sup_mask], data[query_mask]
    sup_labels , query_labels = labels[sup_mask], labels[query_mask]
    for step in range(fas): # inner loop
        pred = learner(sup_data)
        train_loss = loss_func(pred, sup_labels)
        learner.adapt(train_loss)
    query_pred = learner(query_data)
    query_loss = loss_func(query_pred,query_labels)
    query_acc = (torch.argmax(query_pred,1)==query_labels).sum()/len(query_labels)
    return query_loss, query_acc

def train(EPOCHS, meta_model, train_taskset, val_taskset, loss_func, opt, device, model_name):
    log_file = f"./log/{model_name}.txt"
    res = {
        "train_acc_list" : [],
        "val_acc_list" : [],
        "train_loss_list" : [],
        "val_loss_list" : [],
    }
    last_val_loss=float("inf")
    for epoch in trange(EPOCHS):
        opt.zero_grad()
        train_loss = 0
        train_acc = 0
        val_loss = 0
        val_acc = 0
        for _ in range(meta_batch):
            learner = meta_model.clone(first_order = False)
            task = train_taskset.sample() # sample 一組 [50,3,84,84],[50]
            query_loss, query_acc = adaptation(task, learner,fas, loss_func, device)
            query_loss.backward()
            train_loss += query_loss.item()
            train_acc += query_acc.item()
            #validation
            learner = meta_model.clone(first_order = False)
            task = val_taskset.sample() # sample 一組 [50,3,84,84],[50]
            query_loss, query_acc = adaptation(task, learner,fas, loss_func, device)
            val_loss+=query_loss.item()
            val_acc +=query_acc.item()
        res["train_acc_list"].append((train_acc/meta_batch))
        res["val_acc_list"].append((val_acc/meta_batch))
        res["train_loss_list"].append((train_loss/meta_batch))
        res["val_loss_list"].append((val_loss/meta_batch))
        for p in meta_model.parameters():
            p.grad.data.mul_(1.0 / meta_batch)
        opt.step()
        
        # write log file
        out = {
            "Epoch" : epoch,
            "train accuracy" : (train_acc/meta_batch),
            "validation accuracy" : (val_acc/meta_batch),
            "train loss" : (train_loss/meta_batch),
            "validation loss" : (val_loss/meta_batch)
        }
        with open(log_file,"a") as f:
            f.write(str(out) + '\n')
            
        # print log
        if (epoch+1)%100==0:
            print(f"Epoch {epoch+1} | Train loss :{train_loss/meta_batch} | Train accuracy : {train_acc/meta_batch}")
            print(f"Epoch {epoch+1} | Validation loss :{val_loss/meta_batch} | Validation accuracy : {val_acc/meta_batch}")
        # save model
        if (epoch+1)%500==0:
            MODEL_PATH=f"./model/{model_name}"
            Path(MODEL_PATH).mkdir(parents=True, exist_ok=True)
            torch.save(meta_model,Path(MODEL_PATH)/f"model_{epoch+1}.pt")
        # early stop
        n = 750
        if len(res["val_loss_list"])>n:
            if (epoch+1)%n == 0:
                avg_val_loss = sum(res["val_loss_list"][-n:])/n
                print("Average val_loss of last 500 epoch:",avg_val_loss)
                if last_val_loss < avg_val_loss:
                    print("Early stop!!!")
                    return res
                last_val_loss = avg_val_loss
    return res

In [20]:
def test(meta_model, test_data, loss_func, device):
    test_loss = 0
    test_acc = 0
    test_sup_images = test_data['sup_images']
    test_sup_labels = test_data['sup_labels']
    test_qry_images = test_data['qry_images']
    all_pred = []
    for i in trange(len(test_sup_images)):
        learner = meta_model.clone(first_order = False)
        sup_image = torch.tensor(test_sup_images[i]).to(device)
        sup_label = torch.tensor(test_sup_labels[i]).to(device)
        qry_image = torch.tensor(test_qry_images[i]).to(device)
        for step in range(fas): # inner loop
            pred = learner(sup_image)
            train_loss = loss_func(pred, sup_label)
            learner.adapt(train_loss)
        output_pred = torch.argmax(learner(qry_image),1).reshape(-1,1).detach().cpu().numpy()
        all_pred += [item for sublist in output_pred for item in sublist]
    return all_pred

In [None]:
maml_lr = 0.5
lr = 1e-3
EPOCHS = 12000
meta_batch = 32
fas = 1
model = l2l.vision.models.ResNet12(output_size=ways).to(device)
meta_model = l2l.algorithms.MAML(model, lr=maml_lr,first_order=False)
opt = torch.optim.AdamW(meta_model.parameters(), lr=lr)
loss_func = nn.CrossEntropyLoss(reduction='mean')
results  = train(EPOCHS, meta_model, train_taskset, val_taskset, loss_func,opt, device, "resnet_task_augmentation_v2")

  1%|▏                  | 100/12000 [33:52<67:06:29, 20.30s/it]

Epoch 100 | Train loss :1.570712849497795 | Train accuracy : 0.32874999032355845
Epoch 100 | Validation loss :1.5530204735696316 | Validation accuracy : 0.34249999141320586


  2%|▎                | 200/12000 [1:07:39<66:30:00, 20.29s/it]

Epoch 200 | Train loss :1.38582044839859 | Train accuracy : 0.40624998742714524
Epoch 200 | Validation loss :1.4666130542755127 | Validation accuracy : 0.3537499886006117


  2%|▍                | 300/12000 [1:41:26<65:55:01, 20.28s/it]

Epoch 300 | Train loss :1.2795218210667372 | Train accuracy : 0.47124998830258846
Epoch 300 | Validation loss :1.4203955307602882 | Validation accuracy : 0.40374998888000846


  3%|▌                | 400/12000 [2:15:14<65:34:45, 20.35s/it]

Epoch 400 | Train loss :1.3099637310951948 | Train accuracy : 0.43874998949468136
Epoch 400 | Validation loss :1.4088043067604303 | Validation accuracy : 0.41249998891726136


  4%|▋                | 500/12000 [2:49:09<65:07:13, 20.39s/it]

Epoch 500 | Train loss :1.1541109550744295 | Train accuracy : 0.5374999875202775
Epoch 500 | Validation loss :1.3580290339887142 | Validation accuracy : 0.3962499920744449


  5%|▊                | 600/12000 [3:23:08<64:40:49, 20.43s/it]

Epoch 600 | Train loss :1.1392588838934898 | Train accuracy : 0.5474999854341149
Epoch 600 | Validation loss :1.3124375138431787 | Validation accuracy : 0.4587499899789691


  6%|▉                | 700/12000 [3:57:10<64:00:15, 20.39s/it]

Epoch 700 | Train loss :1.095608750358224 | Train accuracy : 0.566249992698431
Epoch 700 | Validation loss :1.2518904879689217 | Validation accuracy : 0.4749999865889549


  7%|█▏               | 800/12000 [4:31:14<63:35:37, 20.44s/it]

Epoch 800 | Train loss :1.0472336914390326 | Train accuracy : 0.5799999828450382
Epoch 800 | Validation loss :1.2039296869188547 | Validation accuracy : 0.5237499866634607


  8%|█▎               | 900/12000 [5:05:19<63:06:01, 20.47s/it]

Epoch 900 | Train loss :1.0173608688637614 | Train accuracy : 0.5949999829754233
Epoch 900 | Validation loss :1.2143363747745752 | Validation accuracy : 0.47624998819082975


  8%|█▎              | 1000/12000 [5:39:26<62:38:02, 20.50s/it]

Epoch 1000 | Train loss :0.9909851755946875 | Train accuracy : 0.5924999862909317
Epoch 1000 | Validation loss :1.2419708389788866 | Validation accuracy : 0.4862499851733446


  9%|█▍              | 1100/12000 [6:13:33<61:56:44, 20.46s/it]

Epoch 1100 | Train loss :0.8984892787411809 | Train accuracy : 0.657499979250133
Epoch 1100 | Validation loss :1.1141134984791279 | Validation accuracy : 0.5362499877810478


 10%|█▌              | 1200/12000 [6:47:41<61:27:36, 20.49s/it]

Epoch 1200 | Train loss :0.9865195043385029 | Train accuracy : 0.6187499845400453
Epoch 1200 | Validation loss :1.2251667939126492 | Validation accuracy : 0.4912499859929085


 11%|█▋              | 1300/12000 [7:21:49<60:49:32, 20.46s/it]

Epoch 1300 | Train loss :0.8876569923013449 | Train accuracy : 0.6449999818578362
Epoch 1300 | Validation loss :1.0850266385823488 | Validation accuracy : 0.5637499820441008


 12%|█▊              | 1400/12000 [7:55:59<60:21:31, 20.50s/it]

Epoch 1400 | Train loss :0.9475674843415618 | Train accuracy : 0.6037499820813537
Epoch 1400 | Validation loss :1.0870873052626848 | Validation accuracy : 0.5549999866634607


 12%|██              | 1500/12000 [8:31:16<60:01:08, 20.58s/it]

Epoch 1500 | Train loss :0.8623901894316077 | Train accuracy : 0.6674999836832285
Epoch 1500 | Validation loss :1.0550154130905867 | Validation accuracy : 0.5399999879300594
Average val_loss of last 500 epoch: 1.1692930254812042


 13%|██▏             | 1600/12000 [9:05:27<59:12:14, 20.49s/it]

Epoch 1600 | Train loss :0.857006604783237 | Train accuracy : 0.662499981932342
Epoch 1600 | Validation loss :1.14059142395854 | Validation accuracy : 0.5412499848753214


 14%|██▎             | 1700/12000 [9:40:10<58:38:54, 20.50s/it]

Epoch 1700 | Train loss :0.8806136418133974 | Train accuracy : 0.6537499842233956
Epoch 1700 | Validation loss :1.0537420641630888 | Validation accuracy : 0.5599999874830246


 15%|██▎            | 1800/12000 [10:14:22<58:06:12, 20.51s/it]

Epoch 1800 | Train loss :0.8286666385829449 | Train accuracy : 0.678749980404973
Epoch 1800 | Validation loss :0.9949292801320553 | Validation accuracy : 0.5999999800696969


 16%|██▍            | 1900/12000 [10:48:33<57:34:25, 20.52s/it]

Epoch 1900 | Train loss :0.8331728307530284 | Train accuracy : 0.6724999789148569
Epoch 1900 | Validation loss :1.0567252319306135 | Validation accuracy : 0.5662499880418181


 17%|██▌            | 2000/12000 [11:22:45<57:01:18, 20.53s/it]

Epoch 2000 | Train loss :0.7290941849350929 | Train accuracy : 0.702499981969595
Epoch 2000 | Validation loss :0.9380176682025194 | Validation accuracy : 0.618749987334013


 18%|██▋            | 2100/12000 [11:56:57<56:25:37, 20.52s/it]

Epoch 2100 | Train loss :0.8267566082067788 | Train accuracy : 0.683749983087182
Epoch 2100 | Validation loss :1.0906707849353552 | Validation accuracy : 0.5599999874830246


 18%|██▊            | 2200/12000 [12:31:06<55:39:43, 20.45s/it]

Epoch 2200 | Train loss :0.7567527676001191 | Train accuracy : 0.7124999798834324
Epoch 2200 | Validation loss :1.0072479220107198 | Validation accuracy : 0.6162499841302633


 19%|██▊            | 2250/12000 [12:48:10<55:28:20, 20.48s/it]

Average val_loss of last 500 epoch: 1.0531685971257587


 19%|██▉            | 2300/12000 [13:05:14<55:13:02, 20.49s/it]

Epoch 2300 | Train loss :0.7463441910222173 | Train accuracy : 0.7037499863654375
Epoch 2300 | Validation loss :0.9180787429213524 | Validation accuracy : 0.6412499826401472


 20%|███            | 2400/12000 [13:39:21<54:32:18, 20.45s/it]

Epoch 2400 | Train loss :0.696585207246244 | Train accuracy : 0.7437499761581421
Epoch 2400 | Validation loss :1.0358286164700985 | Validation accuracy : 0.5724999913945794


 21%|███▏           | 2500/12000 [14:13:26<53:57:50, 20.45s/it]

Epoch 2500 | Train loss :0.7615073553752154 | Train accuracy : 0.6962499786168337
Epoch 2500 | Validation loss :0.9732493152841926 | Validation accuracy : 0.5924999872222543


 22%|███▎           | 2600/12000 [14:47:31<53:24:11, 20.45s/it]

Epoch 2600 | Train loss :0.7009940398856997 | Train accuracy : 0.7262499816715717
Epoch 2600 | Validation loss :0.9650092059746385 | Validation accuracy : 0.6037499867379665


 22%|███▍           | 2700/12000 [15:21:34<52:49:56, 20.45s/it]

Epoch 2700 | Train loss :0.7028951831161976 | Train accuracy : 0.722499979659915
Epoch 2700 | Validation loss :0.9412793992087245 | Validation accuracy : 0.6287499833852053


 23%|███▌           | 2800/12000 [15:55:37<52:17:51, 20.46s/it]

Epoch 2800 | Train loss :0.6796522475779057 | Train accuracy : 0.7312499769032001
Epoch 2800 | Validation loss :0.9909855024889112 | Validation accuracy : 0.6112499851733446


 24%|███▋           | 2900/12000 [16:29:40<51:40:03, 20.44s/it]

Epoch 2900 | Train loss :0.6539186351001263 | Train accuracy : 0.7687499802559614
Epoch 2900 | Validation loss :0.9820780791342258 | Validation accuracy : 0.6137499855831265


 25%|███▊           | 3000/12000 [17:03:41<50:58:26, 20.39s/it]

Epoch 3000 | Train loss :0.6432654019445181 | Train accuracy : 0.7637499738484621
Epoch 3000 | Validation loss :0.9600678272545338 | Validation accuracy : 0.6087499791756272
Average val_loss of last 500 epoch: 0.9890074621954312


 26%|███▉           | 3100/12000 [17:37:43<50:27:49, 20.41s/it]

Epoch 3100 | Train loss :0.6218223357573152 | Train accuracy : 0.762499975040555
Epoch 3100 | Validation loss :0.9648890178650618 | Validation accuracy : 0.587499983375892


 27%|████           | 3200/12000 [18:11:44<49:53:13, 20.41s/it]

Epoch 3200 | Train loss :0.562789237126708 | Train accuracy : 0.7837499808520079
Epoch 3200 | Validation loss :1.013189798220992 | Validation accuracy : 0.5849999850615859


 28%|████▏          | 3300/12000 [18:45:45<49:22:12, 20.43s/it]

Epoch 3300 | Train loss :0.5091496123932302 | Train accuracy : 0.8124999795109034
Epoch 3300 | Validation loss :0.9198005357757211 | Validation accuracy : 0.663749978877604


 28%|████▎          | 3400/12000 [19:19:46<48:42:19, 20.39s/it]

Epoch 3400 | Train loss :0.6262475796975195 | Train accuracy : 0.7649999763816595
Epoch 3400 | Validation loss :0.9794522617012262 | Validation accuracy : 0.617499983869493


 29%|████▍          | 3500/12000 [19:53:43<48:07:10, 20.38s/it]

Epoch 3500 | Train loss :0.6467301771044731 | Train accuracy : 0.7487499797716737
Epoch 3500 | Validation loss :0.9574658563360572 | Validation accuracy : 0.593749986961484


 30%|████▌          | 3600/12000 [20:28:07<50:38:09, 21.70s/it]

Epoch 3600 | Train loss :0.6269261424895376 | Train accuracy : 0.7637499831616879
Epoch 3600 | Validation loss :0.9270792594179511 | Validation accuracy : 0.6374999834224582


 31%|████▋          | 3700/12000 [21:02:32<46:52:30, 20.33s/it]

Epoch 3700 | Train loss :0.5636445572599769 | Train accuracy : 0.7912499811500311
Epoch 3700 | Validation loss :0.9328939588740468 | Validation accuracy : 0.6437499783933163


 31%|████▋          | 3750/12000 [21:19:29<46:36:32, 20.34s/it]

Average val_loss of last 500 epoch: 0.9348789933913698


 32%|████▊          | 3800/12000 [21:36:27<46:19:04, 20.33s/it]

Epoch 3800 | Train loss :0.5582282887771726 | Train accuracy : 0.7737499745562673
Epoch 3800 | Validation loss :0.8357252394780517 | Validation accuracy : 0.6612499784678221


 32%|████▉          | 3900/12000 [22:10:23<45:46:10, 20.34s/it]

Epoch 3900 | Train loss :0.6313612057128921 | Train accuracy : 0.7687499774619937
Epoch 3900 | Validation loss :0.91527886595577 | Validation accuracy : 0.6299999840557575


 33%|█████          | 4000/12000 [22:44:18<45:15:11, 20.36s/it]

Epoch 4000 | Train loss :0.4645992508158088 | Train accuracy : 0.8162499740719795
Epoch 4000 | Validation loss :0.8625044822692871 | Validation accuracy : 0.6562499888241291


 34%|█████          | 4096/12000 [23:16:54<44:43:25, 20.37s/it]

In [None]:
def moving_average(x, w):
    return np.convolve(x, np.ones(w), 'valid') / w
df1 = pd.DataFrame(results["train_acc_list"],columns=['train_accuracy'])
df2 = pd.DataFrame([i for i in results["val_acc_list"]],columns=['val_accuracy'])
df = pd.concat([df1,df2],axis=1)
df.plot()

In [None]:
df1 = pd.DataFrame(results["train_loss_list"][1000:],columns=['train_loss'])
df2 = pd.DataFrame([i for i in results["val_loss_list"][1000:]],columns=['val_loss'])
df = pd.concat([df1,df2],axis=1)
df.plot()