In [1]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
import copy
from collections import OrderedDict

class AdapterLayer(nn.Module):
    """Residual bottleneck adapter computing ``X + up(relu(down(X)))``.

    The bottleneck width is ``input_size // reduction_factor``. Setting the
    attribute ``skip_adapter = True`` turns the layer into an identity: the
    input tensor object is returned unchanged.

    Args:
        input_size: size of the last dimension of the input; used as the
            in-features of the down projection and out-features of the up projection.
        reduction_factor: integer divisor applied to ``input_size`` to get the
            bottleneck width.

    Raises:
        ValueError: if ``input_size // reduction_factor`` is smaller than 1, which
            would otherwise silently build a zero-width bottleneck.
    """

    def __init__(self, input_size, reduction_factor):
        super(AdapterLayer, self).__init__()
        bottleneck_size = input_size // reduction_factor
        if bottleneck_size < 1:
            raise ValueError(
                f"reduction_factor={reduction_factor} leaves no bottleneck units "
                f"for input_size={input_size}"
            )
        self.skip_adapter = False
        self.adapter = nn.Sequential(nn.Linear(input_size, bottleneck_size),
                                     nn.ReLU(),
                                     nn.Linear(bottleneck_size, input_size))
        self.adapter.apply(self.init_weights)

    def init_weights(self, m, std = 1e-2):
        """Initialise an ``nn.Linear`` with small Gaussian values clipped to ±2·std.

        Weights (and biases, if present) are drawn from N(0, std^2) and clamped to
        [-2*std, 2*std], so the residual branch starts small. Modules that are not
        ``nn.Linear`` are ignored. ``nn.Module.apply`` also calls this on the ReLU
        and on the ``Sequential`` container itself.

        Args:
            m: module to (possibly) initialise in place.
            std: standard deviation of the normal draw; also sets the clamp range.
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, std = std)
            if m.bias is not None:
                torch.nn.init.normal_(m.bias, std = std)
            # Clamp in place under no_grad instead of reassigning `.data`: keeps the
            # same Parameter objects and avoids the discouraged `.data` API.
            with torch.no_grad():
                m.weight.clamp_(min = -2*std, max = 2*std)
                if m.bias is not None:
                    m.bias.clamp_(min = -2*std, max = 2*std)

    def forward(self, X):
        """Return ``X`` unchanged when ``skip_adapter`` is set, else ``adapter(X) + X``."""
        if self.skip_adapter:
            return X
        return self.adapter(X) + X

In [2]:
class Adapter_wrapper(nn.Module):
    """Run a GPT-2-style backbone with one ``AdapterLayer`` after each block's MLP.

    Between calls the backbone is left unmodified. ``forward`` temporarily replaces
    ``backbone.transformer.h[n].mlp`` with ``Sequential(MLP, Adapter)``, runs the
    backbone, and then restores the original MLP modules. The restore also happens
    when the backbone raises.

    Args:
        backbone: model exposing ``transformer.h`` (a sequence of blocks that each
            have an ``mlp`` submodule), ``config.hidden_size`` and ``device``, such as
            a Hugging Face ``GPT2LMHeadModel``.
        reduction_factor: bottleneck reduction passed to every ``AdapterLayer``.
    """

    def __init__(self, backbone, reduction_factor = 32):
        super(Adapter_wrapper, self).__init__()
        self.backbone = backbone
        self.backbone_layers = len(self.backbone.transformer.h)
        hidden_size = self.backbone.config.hidden_size

        # One adapter per transformer block, created on the backbone's device.
        self.adapters = nn.ModuleList([AdapterLayer(hidden_size, reduction_factor).to(backbone.device)
                                       for _ in range(self.backbone_layers)])

    def forward(self, *args, **kwargs):
        """Call the backbone with adapters attached; all arguments are forwarded unchanged."""
        self.attach_adapters()
        try:
            # Call the module (not `.forward`) so any registered forward hooks still run.
            return self.backbone(*args, **kwargs)
        finally:
            # Always restore the plain MLPs. Otherwise a failed call would leave the
            # adapters attached, and the next call would nest a second wrapper.
            self.detach_adapters()

    def _is_attached(self, n):
        """Return True if block ``n``'s MLP is currently wrapped with ``self.adapters[n]``."""
        mlp = self.backbone.transformer.h[n].mlp
        return isinstance(mlp, nn.Sequential) and getattr(mlp, 'Adapter', None) is self.adapters[n]

    def attach_adapters(self):
        """Wrap each block's MLP as ``Sequential(MLP, Adapter)``; already-wrapped blocks are skipped."""
        for n in range(self.backbone_layers):
            if self._is_attached(n):
                continue
            block = self.backbone.transformer.h[n]
            block.mlp = nn.Sequential(OrderedDict([('MLP', block.mlp),
                                                   ('Adapter', self.adapters[n])]))

    def detach_adapters(self):
        """Restore each wrapped block's original MLP; blocks without an adapter are left as-is."""
        for n in range(self.backbone_layers):
            if self._is_attached(n):
                block = self.backbone.transformer.h[n]
                block.mlp = block.mlp.MLP

In [14]:
# Load the pretrained 'gpt2' checkpoint. Use the GPU when one is available and fall
# back to CPU, so the notebook still runs on machines without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2').to(DEVICE)

# Planned next steps (not executed yet):
#value_model = Adapter_wrapper(gpt2_model)

# NOTE(review): gpt2_model.parameters() would fine-tune the whole backbone. If only the
# adapters are meant to be trained, pass value_model.adapters.parameters() — confirm intent.
#optimizer = torch.optim.Adam(gpt2_model.parameters(), lr=3e-4)

In [15]:
import numpy as np
# Fraction of trainable parameters reported by `lm_head.parameters()`.
# `numel()` gives each tensor's element count directly (as a plain Python int),
# replacing `np.prod(p.size())` and the filter/lambda pipeline.
all_params = sum(p.numel() for p in gpt2_model.parameters() if p.requires_grad)
head_params = sum(p.numel() for p in gpt2_model.lm_head.parameters() if p.requires_grad)

# NOTE(review): in Hugging Face GPT-2, `lm_head.weight` is normally tied to the input token
# embedding (`transformer.wte.weight`), and `parameters()` yields a shared tensor only once.
# If the tie is in effect (check `gpt2_model.lm_head.weight is gpt2_model.transformer.wte.weight`),
# `head_params` is the embedding matrix already counted in `all_params`, not head-only weights.
head_params / all_params

0.31016904172658316

In [17]:
# Total trainable parameter count of gpt2_model, as summed from gpt2_model.parameters().
all_params

124439808

In [10]:
# Total parameter count in millions (the original scratch cell compared 124439808 with
# "120 000000", which is a SyntaxError: Python digit grouping uses underscores, e.g. 120_000_000).
all_params / 1_000_000

In [18]:
# Trainable parameter count reported by gpt2_model.lm_head.parameters().
# NOTE(review): this value equals 50257 x 768, which matches GPT-2's vocab x hidden size;
# presumably it is the weight tied to transformer.wte — confirm before treating it as head-only.
head_params

38597376