In [17]:
"""
Adapted from https://github.com/lukemelas/simple-bert
"""
 
import numpy as np
from torch import nn
from torch import Tensor 
from torch.nn import functional as F

PEFT = 0

def split_last(x, shape):
    "split the last dimension to given shape"
    shape = list(shape)
    assert shape.count(-1) <= 1
    if -1 in shape:
        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
    return x.view(*x.size()[:-1], *shape)


def merge_last(x, n_dims):
    "merge the last n_dims to a dimension"
    s = x.size()
    assert n_dims > 1 and n_dims < len(s)
    return x.view(*s[:-n_dims], -1)

class TrainableEltwiseLayer(nn.Module):
    def __init__(self, n, h, w):
        super(TrainableEltwiseLayer, self).__init__()
        self.weights = nn.Parameter(torch.Tensor(1, n, h, w))  # define the trainable parameter

    def forward(self, x):
        # assuming x is of size b-1-h-w
        return x * self.weights  # element-wise multiplication


class MultiHeadedSelfAttention(nn.Module):
    """Multi-Headed Dot Product Attention"""
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)
        
        head_dim = dim // num_heads

#         self.lk = TrainableEltwiseLayer(1, head_dim, 1)
#         nn.init.ones_(self.lk.weights)
#         self.lv = TrainableEltwiseLayer(1, 1, head_dim)
#         nn.init.ones_(self.lv.weights)

        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None # for visualization

    def forward(self, x, mask):
        """
        x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size) x S(seq_len))
        * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
        """
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
        
        #PEFT
        #if PEFT:
#         k = self.lk(k.transpose(-2, -1))
#         v = self.lv(v)
#         scores = q @ k / np.sqrt(k.size(-1))
#         else:
#           # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        if mask is not None:
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(F.softmax(scores, dim=-1))
        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
        h = (scores @ v).transpose(1, 2).contiguous()
        # -merge-> (B, S, D)
        h = merge_last(h, 2)
        self.scores = scores
        return h


class PositionWiseFeedForward(nn.Module):
    """FeedForward Neural Networks for each position"""
    def __init__(self, dim, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, ff_dim)
        #PEFT
        #if PEFT:
#         self.fcpeft = TrainableEltwiseLayer(1, 1, 3072)
#         nn.init.ones_(self.fcpeft.weights)
        
        self.fc2 = nn.Linear(ff_dim, dim)

    def forward(self, x):
        # (B, S, D) -> (B, S, D_ff) -> (B, S, D)
        # if PEFT:
        #     return self.fc2(self.fcpeft(F.gelu(self.fc1(x))))
        x = F.gelu(self.fc1(x))
        #if PEFT:
#         x = self.fcpeft(x)
#         x = x.squeeze(0)
        x = self.fc2(x)
        return x


class Block(nn.Module):
    """Transformer Block"""
    def __init__(self, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.attn = MultiHeadedSelfAttention(dim, num_heads, dropout)
        self.proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.pwff = PositionWiseFeedForward(dim, ff_dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask):
        h = self.drop(self.proj(self.attn(self.norm1(x), mask)))
        x = x + h
        h = self.drop(self.pwff(self.norm2(x)))
        x = x + h
        return x


class Transformer(nn.Module):
    """Transformer with Self-Attentive Blocks"""
    def __init__(self, num_layers, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for block in self.blocks:
            x = block(x, mask)
        return x


In [18]:
"""model.py - Model and module class for ViT.
   They are built to mirror those in the official Jax implementation.
"""

from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F

from pytorch_pretrained_vit.utils import load_pretrained_weights, as_tuple
from pytorch_pretrained_vit.configs import PRETRAINED_MODELS


class PositionalEmbedding1D(nn.Module):
    """Adds (optionally learned) positional embeddings to the inputs."""

    def __init__(self, seq_len, dim):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.zeros(1, seq_len, dim))
    
    def forward(self, x):
        """Input has shape `(batch_size, seq_len, emb_dim)`"""
        return x + self.pos_embedding


class ViT(nn.Module):
    """
    Args:
        name (str): Model name, e.g. 'B_16'
        pretrained (bool): Load pretrained weights
        in_channels (int): Number of channels in input data
        num_classes (int): Number of classes, default 1000

    References:
        [1] https://openreview.net/forum?id=YicbFdNTTy
    """

    def __init__(
        self, 
        name: Optional[str] = None, 
        pretrained: bool = False, 
        patches: int = 16,
        dim: int = 768,
        ff_dim: int = 3072,
        num_heads: int = 12,
        num_layers: int = 12,
        attention_dropout_rate: float = 0.0,
        dropout_rate: float = 0.1,
        representation_size: Optional[int] = None,
        load_repr_layer: bool = False,
        classifier: str = 'token',
        positional_embedding: str = '1d',
        in_channels: int = 3, 
        image_size: Optional[int] = None,
        num_classes: Optional[int] = None,
    ):
        super().__init__()

        # Configuration
        if name is None:
            check_msg = 'must specify name of pretrained model'
            assert not pretrained, check_msg
            #assert not resize_positional_embedding, check_msg
            if num_classes is None:
                num_classes = 1000
            if image_size is None:
                image_size = 384
        else:  # load pretrained model
            assert name in PRETRAINED_MODELS.keys(), \
                'name should be in: ' + ', '.join(PRETRAINED_MODELS.keys())
            config = PRETRAINED_MODELS[name]['config']
            patches = config['patches']
            dim = config['dim']
            ff_dim = config['ff_dim']
            num_heads = config['num_heads']
            num_layers = config['num_layers']
            attention_dropout_rate = config['attention_dropout_rate']
            dropout_rate = config['dropout_rate']
            representation_size = config['representation_size']
            classifier = config['classifier']
            if image_size is None:
                image_size = PRETRAINED_MODELS[name]['image_size']
            if num_classes is None:
                num_classes = PRETRAINED_MODELS[name]['num_classes']
        self.image_size = image_size                

        # Image and patch sizes
        h, w = as_tuple(image_size)  # image sizes
        fh, fw = as_tuple(patches)  # patch sizes
        gh, gw = h // fh, w // fw  # number of patches
        seq_len = gh * gw

        # Patch embedding
        self.patch_embedding = nn.Conv2d(in_channels, dim, kernel_size=(fh, fw), stride=(fh, fw))

        # Class token
        if classifier == 'token':
            self.class_token = nn.Parameter(torch.zeros(1, 1, dim))
            seq_len += 1
        
        # Positional embedding
        if positional_embedding.lower() == '1d':
            self.positional_embedding = PositionalEmbedding1D(seq_len, dim)
        else:
            raise NotImplementedError()
        
        # Transformer
        self.transformer = Transformer(num_layers=num_layers, dim=dim, num_heads=num_heads, 
                                       ff_dim=ff_dim, dropout=dropout_rate)
        
        # Representation layer
        if representation_size and load_repr_layer:
            self.pre_logits = nn.Linear(dim, representation_size)
            pre_logits_size = representation_size
        else:
            pre_logits_size = dim

        # Classifier head
        self.norm = nn.LayerNorm(pre_logits_size, eps=1e-6)
        self.fc = nn.Linear(pre_logits_size, num_classes)

        # Initialize weights
        self.init_weights()
        
        # Load pretrained model
        if pretrained:
            pretrained_num_channels = 3
            pretrained_num_classes = PRETRAINED_MODELS[name]['num_classes']
            pretrained_image_size = PRETRAINED_MODELS[name]['image_size']
            load_pretrained_weights(
                self, name, 
                load_first_conv=(in_channels == pretrained_num_channels),
                load_fc=(num_classes == pretrained_num_classes),
                load_repr_layer=load_repr_layer,
                resize_positional_embedding=(image_size != pretrained_image_size),
            )
        
    @torch.no_grad()
    def init_weights(self):
        def _init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)  # _trunc_normal(m.weight, std=0.02)  # from .initialization import _trunc_normal
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)  # nn.init.constant(m.bias, 0)
        self.apply(_init)
        nn.init.constant_(self.fc.weight, 0)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.normal_(self.positional_embedding.pos_embedding, std=0.02)  # _trunc_normal(self.positional_embedding.pos_embedding, std=0.02)
        nn.init.constant_(self.class_token, 0)

    def forward(self, x):
        """Breaks image into patches, applies transformer, applies MLP head.

        Args:
            x (tensor): `b,c,fh,fw`
        """
        b, c, fh, fw = x.shape
        x = self.patch_embedding(x)  # b,d,gh,gw
        x = x.flatten(2).transpose(1, 2)  # b,gh*gw,d
        if hasattr(self, 'class_token'):
            x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1)  # b,gh*gw+1,d
        if hasattr(self, 'positional_embedding'): 
            x = self.positional_embedding(x)  # b,gh*gw+1,d 
        x = self.transformer(x)  # b,gh*gw+1,d
        if hasattr(self, 'pre_logits'):
            x = self.pre_logits(x)
            x = torch.tanh(x)
        if hasattr(self, 'fc'):
            x = self.norm(x)[:, 0]  # b,d
            x = self.fc(x)  # b,num_classes
        return x



In [3]:
from pathlib import Path
import random
from statistics import mean

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
from easyfsl.datasets import CUB
from torch.utils.data import DataLoader
from torchvision import transforms

batch_size = 128
n_workers = 4

train_set = CUB(split="train", transform = transforms.Compose([transforms.ToTensor(),transforms.Resize((384, 384))]), training=True)
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    num_workers=n_workers,
    pin_memory=True,
    shuffle=True,
)

In [5]:
num_classes = len(set(train_set.get_labels()))
num_classes

140

In [19]:
class CUBModel(nn.Module):
    def __init__(self,num_classes):
        super().__init__()
        self.vit = ViT('B_32_imagenet1k',pretrained=False)
        self.linear_last = nn.Linear(1000,num_classes)
        
    def forward(self, xb):
        out = self.vit(xb)
        out = F.relu(out)
        out = self.linear_last(out)
        return out

In [7]:
model = CUBModel(num_classes = num_classes)

In [11]:
from easyfsl.samplers import TaskSampler

random_seed = 0
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

n_way = 5
n_shot = 5
n_query = 10

test_set = CUB(split="test", transform = transforms.Compose([transforms.ToTensor(),transforms.Resize((384, 384))]), training=False)

n_test_tasks = 50
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    pin_memory=True,
    collate_fn=test_sampler.episodic_collate_fn,
)

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))
def evaluate_query(model, images, labels):
    model.to(device)
    model.eval()
    out = model(images) 
    loss = F.cross_entropy(out, labels)
    acc = accuracy(out, labels)
    return acc
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [9]:
batch_list = []
for batch in test_loader:
    batch_list.append(batch)

In [12]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
n_gpu = torch.cuda.device_count()

#model.to(device)

print("Number of GPUs: "+str(n_gpu))

train_losses = []
n_epochs = 300

peft_list = []

PATH = 'vit_32_384_CUB.pth'

from collections import OrderedDict
state_dict = torch.load(PATH)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove module.
    new_state_dict[name] = v

for batch in batch_list:
    support_images, support_labels, query_images, query_labels,_ = batch
    support_images = support_images.to(device=device, dtype=torch.float32)
    support_labels = support_labels.to(device=device, dtype=torch.long)
    query_images = query_images.to(device=device, dtype=torch.float32)
    query_labels = query_labels.to(device=device, dtype=torch.long)
    
    #PEFT
    PEFT = 1
    num_classes = 5
    model_peft = CUBModel(num_classes = num_classes)
    print(model_peft.load_state_dict(new_state_dict, strict=False))
    if PEFT:
        count = 0
        for name, param in model_peft.named_parameters():
                if 'lk' in name or 'lv' in name or 'fcpeft' in name or 'linear_last' in name:
                    param.requires_grad = True
                    count += 1
                else:
                    param.requires_grad = False
        print(count)
    print("Trainable parameters are (PEFT):" + str(count_parameters(model_peft)))
    for model in [model_peft]:
        
        model.to(device)
        model.train()

        criterion = nn.CrossEntropyLoss()

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)

        total_loss = 0
        model.to(device)
        model.train()
        for epochs in range(n_epochs):
            
            optimizer.zero_grad()
            outputs = model(support_images)
            
            loss = criterion(outputs.squeeze(),support_labels)
            
            if n_gpu>1:
                loss = loss.mean()
            
            train_loss = loss.item()
            total_loss += train_loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            if epochs%100 == 0:
                print("Training Loss is " + str(total_loss))
        print("Training Loss is " + str(total_loss))
        results = evaluate_query(model, query_images, query_labels)
        print(results.item())
        peft_list.append(results.item())

Number of GPUs: 10
_IncompatibleKeys(missing_keys=['vit.transformer.blocks.0.attn.lk.weights', 'vit.transformer.blocks.0.attn.lv.weights', 'vit.transformer.blocks.0.pwff.fcpeft.weights', 'vit.transformer.blocks.1.attn.lk.weights', 'vit.transformer.blocks.1.attn.lv.weights', 'vit.transformer.blocks.1.pwff.fcpeft.weights', 'vit.transformer.blocks.2.attn.lk.weights', 'vit.transformer.blocks.2.attn.lv.weights', 'vit.transformer.blocks.2.pwff.fcpeft.weights', 'vit.transformer.blocks.3.attn.lk.weights', 'vit.transformer.blocks.3.attn.lv.weights', 'vit.transformer.blocks.3.pwff.fcpeft.weights', 'vit.transformer.blocks.4.attn.lk.weights', 'vit.transformer.blocks.4.attn.lv.weights', 'vit.transformer.blocks.4.pwff.fcpeft.weights', 'vit.transformer.blocks.5.attn.lk.weights', 'vit.transformer.blocks.5.attn.lv.weights', 'vit.transformer.blocks.5.pwff.fcpeft.weights', 'vit.transformer.blocks.6.attn.lk.weights', 'vit.transformer.blocks.6.attn.lv.weights', 'vit.transformer.blocks.6.pwff.fcpeft.weights

Training Loss is 64.87802556157112
Training Loss is 77.67067790776491
Training Loss is 86.54957052320242
0.9399999976158142
_IncompatibleKeys(missing_keys=['vit.transformer.blocks.0.attn.lk.weights', 'vit.transformer.blocks.0.attn.lv.weights', 'vit.transformer.blocks.0.pwff.fcpeft.weights', 'vit.transformer.blocks.1.attn.lk.weights', 'vit.transformer.blocks.1.attn.lv.weights', 'vit.transformer.blocks.1.pwff.fcpeft.weights', 'vit.transformer.blocks.2.attn.lk.weights', 'vit.transformer.blocks.2.attn.lv.weights', 'vit.transformer.blocks.2.pwff.fcpeft.weights', 'vit.transformer.blocks.3.attn.lk.weights', 'vit.transformer.blocks.3.attn.lv.weights', 'vit.transformer.blocks.3.pwff.fcpeft.weights', 'vit.transformer.blocks.4.attn.lk.weights', 'vit.transformer.blocks.4.attn.lv.weights', 'vit.transformer.blocks.4.pwff.fcpeft.weights', 'vit.transformer.blocks.5.attn.lk.weights', 'vit.transformer.blocks.5.attn.lv.weights', 'vit.transformer.blocks.5.pwff.fcpeft.weights', 'vit.transformer.blocks.6.at

Training Loss is 62.45917774736881
Training Loss is 74.90675653517246
Training Loss is 83.57948070764542
0.8799999952316284
_IncompatibleKeys(missing_keys=['vit.transformer.blocks.0.attn.lk.weights', 'vit.transformer.blocks.0.attn.lv.weights', 'vit.transformer.blocks.0.pwff.fcpeft.weights', 'vit.transformer.blocks.1.attn.lk.weights', 'vit.transformer.blocks.1.attn.lv.weights', 'vit.transformer.blocks.1.pwff.fcpeft.weights', 'vit.transformer.blocks.2.attn.lk.weights', 'vit.transformer.blocks.2.attn.lv.weights', 'vit.transformer.blocks.2.pwff.fcpeft.weights', 'vit.transformer.blocks.3.attn.lk.weights', 'vit.transformer.blocks.3.attn.lv.weights', 'vit.transformer.blocks.3.pwff.fcpeft.weights', 'vit.transformer.blocks.4.attn.lk.weights', 'vit.transformer.blocks.4.attn.lv.weights', 'vit.transformer.blocks.4.pwff.fcpeft.weights', 'vit.transformer.blocks.5.attn.lk.weights', 'vit.transformer.blocks.5.attn.lv.weights', 'vit.transformer.blocks.5.pwff.fcpeft.weights', 'vit.transformer.blocks.6.at

Training Loss is 69.58752398192883
Training Loss is 81.8641387373209
Training Loss is 90.33695640414953
0.9800000190734863
_IncompatibleKeys(missing_keys=['vit.transformer.blocks.0.attn.lk.weights', 'vit.transformer.blocks.0.attn.lv.weights', 'vit.transformer.blocks.0.pwff.fcpeft.weights', 'vit.transformer.blocks.1.attn.lk.weights', 'vit.transformer.blocks.1.attn.lv.weights', 'vit.transformer.blocks.1.pwff.fcpeft.weights', 'vit.transformer.blocks.2.attn.lk.weights', 'vit.transformer.blocks.2.attn.lv.weights', 'vit.transformer.blocks.2.pwff.fcpeft.weights', 'vit.transformer.blocks.3.attn.lk.weights', 'vit.transformer.blocks.3.attn.lv.weights', 'vit.transformer.blocks.3.pwff.fcpeft.weights', 'vit.transformer.blocks.4.attn.lk.weights', 'vit.transformer.blocks.4.attn.lv.weights', 'vit.transformer.blocks.4.pwff.fcpeft.weights', 'vit.transformer.blocks.5.attn.lk.weights', 'vit.transformer.blocks.5.attn.lv.weights', 'vit.transformer.blocks.5.pwff.fcpeft.weights', 'vit.transformer.blocks.6.att

Training Loss is 55.93259063363075
Training Loss is 64.98009048774838
Training Loss is 71.26551017165184
0.9800000190734863
_IncompatibleKeys(missing_keys=['vit.transformer.blocks.0.attn.lk.weights', 'vit.transformer.blocks.0.attn.lv.weights', 'vit.transformer.blocks.0.pwff.fcpeft.weights', 'vit.transformer.blocks.1.attn.lk.weights', 'vit.transformer.blocks.1.attn.lv.weights', 'vit.transformer.blocks.1.pwff.fcpeft.weights', 'vit.transformer.blocks.2.attn.lk.weights', 'vit.transformer.blocks.2.attn.lv.weights', 'vit.transformer.blocks.2.pwff.fcpeft.weights', 'vit.transformer.blocks.3.attn.lk.weights', 'vit.transformer.blocks.3.attn.lv.weights', 'vit.transformer.blocks.3.pwff.fcpeft.weights', 'vit.transformer.blocks.4.attn.lk.weights', 'vit.transformer.blocks.4.attn.lv.weights', 'vit.transformer.blocks.4.pwff.fcpeft.weights', 'vit.transformer.blocks.5.attn.lk.weights', 'vit.transformer.blocks.5.attn.lv.weights', 'vit.transformer.blocks.5.pwff.fcpeft.weights', 'vit.transformer.blocks.6.at

Training Loss is 58.16931040585041
Training Loss is 69.04219924658537
Training Loss is 76.45941636711359
0.9800000190734863
_IncompatibleKeys(missing_keys=['vit.transformer.blocks.0.attn.lk.weights', 'vit.transformer.blocks.0.attn.lv.weights', 'vit.transformer.blocks.0.pwff.fcpeft.weights', 'vit.transformer.blocks.1.attn.lk.weights', 'vit.transformer.blocks.1.attn.lv.weights', 'vit.transformer.blocks.1.pwff.fcpeft.weights', 'vit.transformer.blocks.2.attn.lk.weights', 'vit.transformer.blocks.2.attn.lv.weights', 'vit.transformer.blocks.2.pwff.fcpeft.weights', 'vit.transformer.blocks.3.attn.lk.weights', 'vit.transformer.blocks.3.attn.lv.weights', 'vit.transformer.blocks.3.pwff.fcpeft.weights', 'vit.transformer.blocks.4.attn.lk.weights', 'vit.transformer.blocks.4.attn.lv.weights', 'vit.transformer.blocks.4.pwff.fcpeft.weights', 'vit.transformer.blocks.5.attn.lk.weights', 'vit.transformer.blocks.5.attn.lv.weights', 'vit.transformer.blocks.5.pwff.fcpeft.weights', 'vit.transformer.blocks.6.at

Training Loss is 56.95721662044525
Training Loss is 67.5614108517766
Training Loss is 74.89263455942273
0.9399999976158142
_IncompatibleKeys(missing_keys=['vit.transformer.blocks.0.attn.lk.weights', 'vit.transformer.blocks.0.attn.lv.weights', 'vit.transformer.blocks.0.pwff.fcpeft.weights', 'vit.transformer.blocks.1.attn.lk.weights', 'vit.transformer.blocks.1.attn.lv.weights', 'vit.transformer.blocks.1.pwff.fcpeft.weights', 'vit.transformer.blocks.2.attn.lk.weights', 'vit.transformer.blocks.2.attn.lv.weights', 'vit.transformer.blocks.2.pwff.fcpeft.weights', 'vit.transformer.blocks.3.attn.lk.weights', 'vit.transformer.blocks.3.attn.lv.weights', 'vit.transformer.blocks.3.pwff.fcpeft.weights', 'vit.transformer.blocks.4.attn.lk.weights', 'vit.transformer.blocks.4.attn.lv.weights', 'vit.transformer.blocks.4.pwff.fcpeft.weights', 'vit.transformer.blocks.5.attn.lk.weights', 'vit.transformer.blocks.5.attn.lv.weights', 'vit.transformer.blocks.5.pwff.fcpeft.weights', 'vit.transformer.blocks.6.att

Training Loss is 66.04757411777973
Training Loss is 78.92369071394205
Training Loss is 87.66179888695478
0.7799999713897705
_IncompatibleKeys(missing_keys=['vit.transformer.blocks.0.attn.lk.weights', 'vit.transformer.blocks.0.attn.lv.weights', 'vit.transformer.blocks.0.pwff.fcpeft.weights', 'vit.transformer.blocks.1.attn.lk.weights', 'vit.transformer.blocks.1.attn.lv.weights', 'vit.transformer.blocks.1.pwff.fcpeft.weights', 'vit.transformer.blocks.2.attn.lk.weights', 'vit.transformer.blocks.2.attn.lv.weights', 'vit.transformer.blocks.2.pwff.fcpeft.weights', 'vit.transformer.blocks.3.attn.lk.weights', 'vit.transformer.blocks.3.attn.lv.weights', 'vit.transformer.blocks.3.pwff.fcpeft.weights', 'vit.transformer.blocks.4.attn.lk.weights', 'vit.transformer.blocks.4.attn.lv.weights', 'vit.transformer.blocks.4.pwff.fcpeft.weights', 'vit.transformer.blocks.5.attn.lk.weights', 'vit.transformer.blocks.5.attn.lv.weights', 'vit.transformer.blocks.5.pwff.fcpeft.weights', 'vit.transformer.blocks.6.at

Training Loss is 61.906217232346535
Training Loss is 73.25086195766926
Training Loss is 81.04978217929602
0.7799999713897705
_IncompatibleKeys(missing_keys=['vit.transformer.blocks.0.attn.lk.weights', 'vit.transformer.blocks.0.attn.lv.weights', 'vit.transformer.blocks.0.pwff.fcpeft.weights', 'vit.transformer.blocks.1.attn.lk.weights', 'vit.transformer.blocks.1.attn.lv.weights', 'vit.transformer.blocks.1.pwff.fcpeft.weights', 'vit.transformer.blocks.2.attn.lk.weights', 'vit.transformer.blocks.2.attn.lv.weights', 'vit.transformer.blocks.2.pwff.fcpeft.weights', 'vit.transformer.blocks.3.attn.lk.weights', 'vit.transformer.blocks.3.attn.lv.weights', 'vit.transformer.blocks.3.pwff.fcpeft.weights', 'vit.transformer.blocks.4.attn.lk.weights', 'vit.transformer.blocks.4.attn.lv.weights', 'vit.transformer.blocks.4.pwff.fcpeft.weights', 'vit.transformer.blocks.5.attn.lk.weights', 'vit.transformer.blocks.5.attn.lv.weights', 'vit.transformer.blocks.5.pwff.fcpeft.weights', 'vit.transformer.blocks.6.a

Training Loss is 60.036123156547546
Training Loss is 69.05720227956772
Training Loss is 75.22141048684716
0.9399999976158142
_IncompatibleKeys(missing_keys=['vit.transformer.blocks.0.attn.lk.weights', 'vit.transformer.blocks.0.attn.lv.weights', 'vit.transformer.blocks.0.pwff.fcpeft.weights', 'vit.transformer.blocks.1.attn.lk.weights', 'vit.transformer.blocks.1.attn.lv.weights', 'vit.transformer.blocks.1.pwff.fcpeft.weights', 'vit.transformer.blocks.2.attn.lk.weights', 'vit.transformer.blocks.2.attn.lv.weights', 'vit.transformer.blocks.2.pwff.fcpeft.weights', 'vit.transformer.blocks.3.attn.lk.weights', 'vit.transformer.blocks.3.attn.lv.weights', 'vit.transformer.blocks.3.pwff.fcpeft.weights', 'vit.transformer.blocks.4.attn.lk.weights', 'vit.transformer.blocks.4.attn.lv.weights', 'vit.transformer.blocks.4.pwff.fcpeft.weights', 'vit.transformer.blocks.5.attn.lk.weights', 'vit.transformer.blocks.5.attn.lv.weights', 'vit.transformer.blocks.5.pwff.fcpeft.weights', 'vit.transformer.blocks.6.a

Training Loss is 55.07974925637245
Training Loss is 65.29332509636879
Training Loss is 72.32929148897529
0.8999999761581421


In [15]:
peft_mean = sum(peft_list)/len(peft_list)
peft_mean

0.9139999973773957

In [16]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp
get_n_params(model_peft)

88340597

In [20]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
n_gpu = torch.cuda.device_count()

#model.to(device)

print("Number of GPUs: "+str(n_gpu))

train_losses = []
n_epochs = 300

peft_list = []

PATH = 'vit_32_384_CUB.pth'

from collections import OrderedDict
state_dict = torch.load(PATH)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove module.
    new_state_dict[name] = v

for batch in batch_list:
    support_images, support_labels, query_images, query_labels,_ = batch
    support_images = support_images.to(device=device, dtype=torch.float32)
    support_labels = support_labels.to(device=device, dtype=torch.long)
    query_images = query_images.to(device=device, dtype=torch.float32)
    query_labels = query_labels.to(device=device, dtype=torch.long)
    
    #Linear Probe
    linear_probe = 1
    model_linear = CUBModel(num_classes = num_classes)
    print(model_linear.load_state_dict(new_state_dict, strict=False))
    if linear_probe:
        count = 0
        for name, param in model_linear.named_parameters():
            if 'linear_last' in name:
                    param.requires_grad = True
                    count += 1
            else:
                  param.requires_grad = False
        print(count)
    print("Trainable parameters are (linear Probe):" + str(count_parameters(model_linear)))
    for model in [model_linear]:
        
        model.to(device)
        model.train()

        criterion = nn.CrossEntropyLoss()

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)

        total_loss = 0
        model.to(device)
        model.train()
        for epochs in range(n_epochs):
            
            optimizer.zero_grad()
            outputs = model(support_images)
            
            loss = criterion(outputs.squeeze(),support_labels)
            
            if n_gpu>1:
                loss = loss.mean()
            
            train_loss = loss.item()
            total_loss += train_loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            if epochs%100 == 0:
                print("Training Loss is " + str(total_loss))
        print("Training Loss is " + str(total_loss))
        results = evaluate_query(model, query_images, query_labels)
        print(results.item())
        peft_list.append(results.item())

Number of GPUs: 10
_IncompatibleKeys(missing_keys=['linear_last.weight', 'linear_last.bias'], unexpected_keys=['linear1.weight', 'linear1.bias'])
2
Trainable parameters are (linear Probe):5005
Training Loss is 1.7479355335235596
Training Loss is 52.41307633370161
Training Loss is 60.97144044190645
Training Loss is 67.06167459487915
0.8799999952316284
_IncompatibleKeys(missing_keys=['linear_last.weight', 'linear_last.bias'], unexpected_keys=['linear1.weight', 'linear1.bias'])
2
Trainable parameters are (linear Probe):5005
Training Loss is 2.009232521057129
Training Loss is 58.70551225543022
Training Loss is 69.4730398580432
Training Loss is 77.13441551476717
0.7400000095367432
_IncompatibleKeys(missing_keys=['linear_last.weight', 'linear_last.bias'], unexpected_keys=['linear1.weight', 'linear1.bias'])
2
Trainable parameters are (linear Probe):5005
Training Loss is 1.9351396560668945
Training Loss is 65.16562560200691
Training Loss is 77.23303701728582
Training Loss is 85.63835795968771


Training Loss is 69.70986123383045
Training Loss is 83.0832047238946
Training Loss is 92.63710507750511
0.9399999976158142
_IncompatibleKeys(missing_keys=['linear_last.weight', 'linear_last.bias'], unexpected_keys=['linear1.weight', 'linear1.bias'])
2
Trainable parameters are (linear Probe):5005
Training Loss is 2.248021125793457
Training Loss is 77.57202236354351
Training Loss is 94.8772085532546
Training Loss is 107.23877052217722
0.9800000190734863
_IncompatibleKeys(missing_keys=['linear_last.weight', 'linear_last.bias'], unexpected_keys=['linear1.weight', 'linear1.bias'])
2
Trainable parameters are (linear Probe):5005
Training Loss is 1.909152626991272
Training Loss is 64.17970290780067
Training Loss is 76.20636510848999
Training Loss is 84.75669967383146
1.0
_IncompatibleKeys(missing_keys=['linear_last.weight', 'linear_last.bias'], unexpected_keys=['linear1.weight', 'linear1.bias'])
2
Trainable parameters are (linear Probe):5005
Training Loss is 1.8528987169265747
Training Loss is

Training Loss is 62.99327786266804
Training Loss is 76.55273133516312
Training Loss is 86.20826017856598
0.8999999761581421


In [21]:
peft_mean = sum(peft_list)/len(peft_list)
peft_mean

0.9235999965667725

In [38]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp
get_n_params(model_linear)

88302197

In [None]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
n_gpu = torch.cuda.device_count()

#model.to(device)

print("Number of GPUs: "+str(n_gpu))

train_losses = []
n_epochs = 300

peft_list = []

PATH = 'vit_32_384_CUB.pth'

from collections import OrderedDict
state_dict = torch.load(PATH)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove module.
    new_state_dict[name] = v

for batch in batch_list:
    support_images, support_labels, query_images, query_labels,_ = batch
    support_images = support_images.to(device=device, dtype=torch.float32)
    support_labels = support_labels.to(device=device, dtype=torch.long)
    query_images = query_images.to(device=device, dtype=torch.float32)
    query_labels = query_labels.to(device=device, dtype=torch.long)
    
    model_full = CUBModel(num_classes = num_classes)
    print(model_full.load_state_dict(new_state_dict, strict=False))
    print("Trainable parameters are (linear Probe):" + str(count_parameters(model_full)))
    for model in [model_full]:
        
        model.to(device)
        model.train()

        criterion = nn.CrossEntropyLoss()

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)

        total_loss = 0
        model.to(device)
        model.train()
        for epochs in range(n_epochs):
            
            optimizer.zero_grad()
            outputs = model(support_images)
            
            loss = criterion(outputs.squeeze(),support_labels)
            
            if n_gpu>1:
                loss = loss.mean()
            
            train_loss = loss.item()
            total_loss += train_loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            if epochs%100 == 0:
                print("Training Loss is " + str(total_loss))
        print("Training Loss is " + str(total_loss))
        results = evaluate_query(model, query_images, query_labels)
        print(results.item())
        peft_list.append(results.item())

Number of GPUs: 10
_IncompatibleKeys(missing_keys=['linear_last.weight', 'linear_last.bias'], unexpected_keys=['linear1.weight', 'linear1.bias'])
Trainable parameters are (linear Probe):88302197
Training Loss is 2.2196006774902344
Training Loss is 4.92985132324975
Training Loss is 4.960869594098767
Training Loss is 4.985465633901185
0.9800000190734863
_IncompatibleKeys(missing_keys=['linear_last.weight', 'linear_last.bias'], unexpected_keys=['linear1.weight', 'linear1.bias'])
Trainable parameters are (linear Probe):88302197
Training Loss is 1.786980152130127
Training Loss is 3.9906130560848396
Training Loss is 4.020017030867166
Training Loss is 4.043470762349898
0.9399999976158142
_IncompatibleKeys(missing_keys=['linear_last.weight', 'linear_last.bias'], unexpected_keys=['linear1.weight', 'linear1.bias'])
Trainable parameters are (linear Probe):88302197
Training Loss is 1.9869985580444336
Training Loss is 4.324548768345267
Training Loss is 4.353567031779676
Training Loss is 4.376600469

In [None]:
peft_mean = sum(peft_list)/len(peft_list)
peft_mean

In [41]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp
get_n_params(model_full)

88302197

In [16]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
n_gpu = torch.cuda.device_count()

#model.to(device)

print("Number of GPUs: "+str(n_gpu))

train_losses = []
n_epochs = 300

acc_list = []

PATH = 'vit_32_384_CUB.pth'

from collections import OrderedDict
state_dict = torch.load(PATH)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove module.
    new_state_dict[name] = v



for batch in test_loader:
    support_images, support_labels, query_images, query_labels,_ = batch
    support_images = support_images.to(device=device, dtype=torch.float32)
    support_labels = support_labels.to(device=device, dtype=torch.long)
    query_images = query_images.to(device=device, dtype=torch.float32)
    query_labels = query_labels.to(device=device, dtype=torch.long)

    #Linear Probe
    linear_probe = 1
    model_linear = CUBModel(num_classes = num_classes)
    print(model_linear.load_state_dict(new_state_dict, strict=False))
    if linear_probe:
        count = 0
        for name, param in model_linear.named_parameters():
            if 'linear_last' in name:
                    param.requires_grad = True
                    count += 1
            else:
                  param.requires_grad = False
        print(count)
    print("Trainable parameters are (linear Probe):" + str(count_parameters(model_linear)))
    #Full model fine tuning
    model_full = CUBModel(num_classes = num_classes)
    
    print(model_full.load_state_dict(new_state_dict, strict=False))
    print("Trainable parameters are (full model):" + str(count_parameters(model_full)))
    
    #PEFT
    PEFT = 1
    num_classes = 5
    model_peft = CUBModel(num_classes = num_classes)
    print(model_peft.load_state_dict(new_state_dict, strict=False))
    if PEFT:
        count = 0
        for name, param in model_peft.named_parameters():
                if 'lk' in name or 'lv' in name or 'fcpeft' in name or 'linear_last' in name:
                    param.requires_grad = True
                    count += 1
                else:
                    param.requires_grad = False
        print(count)
    print("Trainable parameters are (PEFT):" + str(count_parameters(model_peft)))
    for model in [model_linear, model_full, model_peft]:
        
        model.to(device)
        model.train()

        criterion = nn.CrossEntropyLoss()

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)

        total_loss = 0
        model.to(device)
        model.train()
        for epochs in range(n_epochs):
            
            optimizer.zero_grad()
            outputs = model(support_images)
            
            loss = criterion(outputs.squeeze(),support_labels)
            
            if n_gpu>1:
                loss = loss.mean()
            
            train_loss = loss.item()
            total_loss += train_loss
            loss.backward()
            optimizer.step()
            scheduler.step()
            if epochs%100 == 0:
                print("Training Loss is " + str(total_loss))
        print("Training Loss is " + str(total_loss))
        results = evaluate_query(model, query_images, query_labels)
        print(results.item())
        acc_list.append(results.item())

Number of GPUs: 10
_IncompatibleKeys(missing_keys=['linear_last.weight', 'linear_last.bias'], unexpected_keys=['linear1.weight', 'linear1.bias'])
2
Trainable parameters are (linear Probe):140140
_IncompatibleKeys(missing_keys=['linear_last.weight', 'linear_last.bias'], unexpected_keys=['linear1.weight', 'linear1.bias'])
Trainable parameters are (full model):88437332
_IncompatibleKeys(missing_keys=['vit.transformer.blocks.0.attn.lk.weights', 'vit.transformer.blocks.0.attn.lv.weights', 'vit.transformer.blocks.0.pwff.fcpeft.weights', 'vit.transformer.blocks.1.attn.lk.weights', 'vit.transformer.blocks.1.attn.lv.weights', 'vit.transformer.blocks.1.pwff.fcpeft.weights', 'vit.transformer.blocks.2.attn.lk.weights', 'vit.transformer.blocks.2.attn.lv.weights', 'vit.transformer.blocks.2.pwff.fcpeft.weights', 'vit.transformer.blocks.3.attn.lk.weights', 'vit.transformer.blocks.3.attn.lv.weights', 'vit.transformer.blocks.3.pwff.fcpeft.weights', 'vit.transformer.blocks.4.attn.lk.weights', 'vit.transf

AttributeError: 'MultiHeadedSelfAttention' object has no attribute 'lk'

In [89]:
lin_list = []
full_list = []
peft_list = []
for i in range(len(acc_list)):
    if i % 3 == 0:
        lin_list.append(acc_list[i])
    elif i % 3 == 1:
        full_list.append(acc_list[i])
    elif i % 3 == 2:
        peft_list.append(acc_list[i])

lin_mean = sum(lin_list)/len(lin_list)
full_mean = sum(full_list)/len(full_list)
peft_mean = sum(peft_list)/len(peft_list)
print(lin_mean, full_mean, peft_mean)

0.9296153783798218 0.9298039183897131 0.9294117583948023


In [88]:
len(acc_list)

154

In [None]:
lin_list = []
full_list = []
peft_list = []
for i in range(len(acc_list)):
    if i % 3 == 0:
        lin_list.append(acc_list[i])
    elif i % 3 == 1:
        full_list.append(acc_list[i])
    elif i % 3 == 2:
        peft_list.append(acc_list[i])

lin_mean = sum(lin_list)/len(lin_list)
full_mean = sum(full_list)/len(full_list)
peft_mean = sum(peft_list)/len(peft_list)
print(lin_mean, full_mean, peft_mean

In [11]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [28]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

In [30]:
get_n_params(model_peft)

204257909

In [26]:
count_parameters(model_full)

88302197

In [27]:
count_parameters(model_peft)

115960717