In [1]:
with open('shakespear_dat.txt', 'r') as f:
    dat = f.read()
chars = sorted(list(set(dat)))
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda i: ''.join([itos[l] for l in i])
import torch, math
data = torch.tensor(encode(dat))
device = torch.device('cuda:0')
# train test split
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [2]:
# Hyperparameters
block_size = 8
batch_size = 2048
val_batch_size = 2048
vocab_size = len(chars)
emb_size = 32
new_emb_size = 64
num_small_layers = 4
multi_heads = 2
num_large_layers = 8

In [3]:
torch.manual_seed(1337)
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size if split=='train' else val_batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [4]:
class AttentionHead(torch.nn.Module):
    def __init__(self, big=False):
        super(AttentionHead, self).__init__()
        self.k = torch.nn.Linear(new_emb_size if big else emb_size, new_emb_size if big else emb_size, bias=False)
        self.q = torch.nn.Linear(new_emb_size if big else emb_size, new_emb_size if big else emb_size, bias=False)
        self.v = torch.nn.Linear(new_emb_size if big else emb_size, new_emb_size if big else emb_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size,block_size)))
    def forward(self, e):
        keys = self.k(e)
        queries = self.q(e)
        values = self.v(e)
        ret = keys @ queries.transpose(1, 2)*(1.0/math.sqrt(keys.size(-1)))
        ret = torch.masked_fill(ret, self.tril==0, -torch.inf)
        ret = torch.softmax(ret, 2)
        ret = ret @ values
        return ret

In [5]:
class MultiHead(torch.nn.Module):
    def __init__(self, big=False):
        super(MultiHead, self).__init__()
        self.head1 = AttentionHead(big)
        self.head2 = AttentionHead(big)
        self.mh_lin = torch.nn.Linear(multi_heads*(new_emb_size if big else emb_size), new_emb_size if big else emb_size, bias=False)
        self.drop = torch.nn.Dropout(0.1)
    def forward(self, inp):
        x1 = self.head1(inp)
        x2 = self.head2(inp)
        return self.mh_lin(self.drop(torch.cat([x1,x2], dim=2))).relu()


In [6]:
class Block(torch.nn.Module):
    def __init__(self, big=False):
        super(Block, self).__init__()
        self.multihead = MultiHead(big)
        self.l_norm_1 = torch.nn.LayerNorm(new_emb_size if big else emb_size)
        self.l_norm_2 = torch.nn.LayerNorm(new_emb_size if big else emb_size)
        self.ffn = torch.nn.Linear(new_emb_size if big else emb_size, new_emb_size if big else emb_size)
        self.drop = torch.nn.Dropout(0.1)
    def forward(self, inp):
        m = self.l_norm_1(inp + self.multihead(inp))
        m = self.l_norm_2(m + self.ffn(self.drop(m)).relu())
        return m

In [7]:
class PositionalEncoding(torch.nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [8]:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, emb_size)
        self.pe = PositionalEncoding(emb_size)
        self.block1 = Block()
        self.block2 = Block()
        self.block3 = Block()
        self.block4 = Block()
        self.f_lin = torch.nn.Linear(emb_size, vocab_size)
        self.drop = torch.nn.Dropout(0.1)
    def forward(self, inp):
        e = self.embedding(inp)
        e = self.pe(e)
        m = self.block1(e)
        m = self.block2(m)
        m = self.block3(m)
        m = self.block4(m)
        r = self.f_lin(self.drop(m))
        return r

In [9]:
model = Model().to(device)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
@torch.no_grad()
def validate(mdl):
    mdl.eval()
    vx, vy = get_batch('val')
    out = mdl(vx.to(device))
    return loss(out.view(-1, 65), vy.view(-1).to(device)).item()

In [11]:
@torch.enable_grad()
def train(mdl, optim, epochs):
    ind = 0
    for _ in range(epochs):
        mdl.train()
        optim.zero_grad()
        x, y = get_batch('train')
        out = mdl(x.to(device))
        l = loss(out.view(-1, 65), y.view(-1).to(device))
        l.backward()
        optim.step()
        ind += 1
        if ind%100 == 0:
            print(l.item())
            print(f"Validation: {validate(mdl)}")

In [12]:
train(model, optimizer, 1000)

2.9411230087280273
Validation: 2.8431901931762695
2.6231515407562256
Validation: 2.4993515014648438
2.493496894836426
Validation: 2.3980233669281006
2.4104607105255127
Validation: 2.328835964202881
2.414330005645752
Validation: 2.289748191833496
2.3340442180633545
Validation: 2.2688584327697754
2.3394172191619873
Validation: 2.2560274600982666
2.3237712383270264
Validation: 2.234231472015381
2.2912347316741943
Validation: 2.2267966270446777
2.279399871826172
Validation: 2.1773111820220947


In [13]:
k_params = torch.empty((0,emb_size*emb_size)).to(device)
q_params = torch.empty((0,emb_size*emb_size)).to(device)
v_params = torch.empty((0,emb_size*emb_size)).to(device)
lin_params = torch.empty((0, 2*emb_size*emb_size)).to(device)
ffn_w_params = torch.empty((0, emb_size*emb_size)).to(device)
ffn_b_params = torch.empty((0, emb_size)).to(device)
l_norm_w_params = torch.empty((0, emb_size)).to(device)
l_norm_b_params = torch.empty((0, emb_size)).to(device)
for i in model.state_dict():
    if '.k.' in i:
        k_params = torch.cat((k_params, model.state_dict()[i].flatten().view(1,-1)), dim=0)
    elif '.q.' in i:
        q_params = torch.cat((q_params, model.state_dict()[i].flatten().view(1,-1)), dim=0)
    elif '.v.' in i:
        v_params = torch.cat((v_params, model.state_dict()[i].flatten().view(1,-1)), dim=0)
    elif '.mh_lin' in i:
        lin_params = torch.cat((lin_params, model.state_dict()[i].flatten().view(1,-1)), dim = 0)
    elif 'ffn' in i and 'weight' in i:
        ffn_w_params = torch.cat((ffn_w_params, model.state_dict()[i].flatten().view(1,-1)), dim = 0)
    elif 'ffn' in i and 'bias' in i:
        ffn_b_params = torch.cat((ffn_b_params, model.state_dict()[i].flatten().view(1,-1)), dim = 0)
    elif 'l_norm' in i and 'weight' in i:
        l_norm_w_params = torch.cat((l_norm_w_params, model.state_dict()[i].flatten().view(1,-1)), dim = 0)
    elif 'l_norm' in i and 'bias' in i:
        l_norm_b_params = torch.cat((l_norm_b_params, model.state_dict()[i].flatten().view(1,-1)), dim = 0)

In [14]:
class WideModel(torch.nn.Module):
    def __init__(self, big=True):
        super(WideModel, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, new_emb_size if big else emb_size)
        self.pe = PositionalEncoding(new_emb_size if big else emb_size)
        self.block1 = Block(big=big)
        self.block2 = Block(big=big)
        self.block3 = Block(big=big)
        self.block4 = Block(big=big)
        self.f_lin = torch.nn.Linear(new_emb_size if big else emb_size, vocab_size)
        self.drop = torch.nn.Dropout(0.1)
    def forward(self, inp):
        e = self.embedding(inp)
        e = self.pe(e)
        m = self.block1(e)
        m = self.block2(m)
        m = self.block3(m)
        m = self.block4(m)
        r = self.f_lin(self.drop(m))
        return r

G_zero -----------------------------------------------------------------------------------------------------------------

1  0

0  0

In [15]:
wideModel = WideModel().to(device)

In [16]:
def G_zero_emb_or_f_lin(emb):
    zero_tensor = torch.zeros(emb.shape).to(device)
    return torch.cat((emb, zero_tensor), dim=1)

In [17]:
def G_zero_1d(l_norm):
    l_zero = torch.zeros(l_norm.shape).to(device)
    return torch.cat((l_norm, l_zero), dim=0)

In [18]:
def G_zero(weight):
    zero_tensor = torch.zeros(weight.shape).to(device)
    temp_1= torch.cat((weight, zero_tensor), dim=1)
    temp_2= torch.cat((zero_tensor, zero_tensor), dim=1)
    return torch.cat((temp_1, temp_2), dim=0)

1  0

0  0

In [19]:
@torch.no_grad()
def initWideExtendGZero(has_bias=False):
    # Embedding layer
    emb = model.embedding.weight
    wide_emb = G_zero_emb_or_f_lin(emb)
    setattr(wideModel.embedding, 'weight', torch.nn.Parameter(wide_emb, requires_grad=True).to(device))
    # f_lin layer weight
    f_lin_weight = model.f_lin.weight
    wide_f_lin_weight = G_zero_emb_or_f_lin(f_lin_weight)
    setattr(wideModel.f_lin, 'weight', torch.nn.Parameter(wide_f_lin_weight, requires_grad=True).to(device))
    # f_lin layer bias
    if has_bias == True:
        f_lin_bias = model.f_lin.bias
        wide_f_lin_bias = G_zero_emb_or_f_lin(f_lin_bias)
        setattr(wideModel.f_lin, 'bias', torch.nn.Parameter(wide_f_lin_bias, requires_grad=True).to(device))

    
    ii_list = [1,2,3,4]
    for i in range(1, num_small_layers+1):
        ii = ii_list[i-1]
        small_block = getattr(model, f'block{ii}')
        wide_block = getattr(wideModel, f'block{i}')
    
        # Setting FFN Weights
        wide_ffn_weight = G_zero(small_block.ffn.weight)
        setattr(wide_block.ffn, 'weight', torch.nn.Parameter((wide_ffn_weight), requires_grad=True).to(device))
        if has_bias == True:
            wide_ffn_bias = G_zero(small_block.ffn.bias)
            setattr(wide_block.ffn, 'bias', torch.nn.Parameter((wide_ffn_bias), requires_grad=True).to(device))

        # Setting Norm Layers
        l_norm_1_weight = G_zero_1d(small_block.l_norm_1.weight)
        l_norm_2_weight = G_zero_1d(small_block.l_norm_2.weight)
        setattr(wide_block.l_norm_1, 'weight', torch.nn.Parameter((l_norm_1_weight), requires_grad=True).to(device))
        setattr(wide_block.l_norm_2, 'weight', torch.nn.Parameter((l_norm_2_weight), requires_grad=True).to(device))
        if has_bias == True:
            l_norm_1_bias = G_zero_1d(small_block.l_norm_1.bias)
            l_norm_2_bias = G_zero_1d(small_block.l_norm_2.bias)
            setattr(wide_block.l_norm_1, 'bias', torch.nn.Parameter((l_norm_1_bias), requires_grad=True).to(device))
            setattr(wide_block.l_norm_2, 'bias', torch.nn.Parameter((l_norm_2_bias), requires_grad=True).to(device))

        # Setting Multi-Head Attention
        wide_mh_lin_weight = G_zero(small_block.multihead.mh_lin.weight)
        setattr(wide_block.multihead.mh_lin, 'weight', torch.nn.Parameter((wide_mh_lin_weight), requires_grad=True))

        for h in range(1, multi_heads+1):
            head = getattr(wide_block.multihead, f'head{h}')
            small_head = getattr(small_block.multihead, f'head{h}')

            new_k = G_zero(small_head.k.weight)
            new_q = G_zero(small_head.q.weight)
            new_v = G_zero(small_head.v.weight)
            setattr(head.k, 'weight', torch.nn.Parameter((new_k), requires_grad=True).to(device))
            setattr(head.q, 'weight', torch.nn.Parameter((new_q), requires_grad=True).to(device))
            setattr(head.v, 'weight', torch.nn.Parameter((new_v), requires_grad=True).to(device))

In [20]:
initWideExtendGZero()

In [21]:
wideModel.train()
optim_b = torch.optim.Adam(params=wideModel.parameters(), lr = 1e-3)

train(wideModel, optim_b, 1000)

2.346893310546875
Validation: 2.2651448249816895
2.3022103309631348
Validation: 2.231424331665039
2.2866806983947754
Validation: 2.1899495124816895
2.241973400115967
Validation: 2.1845719814300537
2.2491180896759033
Validation: 2.164644241333008
2.24556303024292
Validation: 2.187959909439087
2.2474703788757324
Validation: 2.1449010372161865
2.2014172077178955
Validation: 2.1533243656158447
2.196962594985962
Validation: 2.146059989929199
2.1789472103118896
Validation: 2.1582231521606445


v2 -----------------------------------------------------------------------------------------------------------------

1    0

0    1

In [22]:
def G_zero_v2(weight):
    zero_tensor = torch.zeros(weight.shape).to(device)
    temp_1= torch.cat((weight, zero_tensor), dim=1)
    temp_2= torch.cat((zero_tensor, weight), dim=1)
    return torch.cat((temp_1, temp_2), dim=0)

In [23]:
wideModel = WideModel().to(device)

In [24]:
@torch.no_grad()
def initWideExtendGZeroV2(has_bias=False):
    # Embedding layer
    emb = model.embedding.weight
    wide_emb = G_zero_emb_or_f_lin(emb)
    setattr(wideModel.embedding, 'weight', torch.nn.Parameter(wide_emb, requires_grad=True).to(device))
    # f_lin layer weight
    f_lin_weight = model.f_lin.weight
    wide_f_lin_weight = G_zero_emb_or_f_lin(f_lin_weight)
    setattr(wideModel.f_lin, 'weight', torch.nn.Parameter(wide_f_lin_weight, requires_grad=True).to(device))
    # f_lin layer bias
    if has_bias == True:
        f_lin_bias = model.f_lin.bias
        wide_f_lin_bias = G_zero_emb_or_f_lin(f_lin_bias)
        setattr(wideModel.f_lin, 'bias', torch.nn.Parameter(wide_f_lin_bias, requires_grad=True).to(device))

    
    ii_list = [1,2,3,4]
    for i in range(1, num_small_layers+1):
        ii = ii_list[i-1]
        small_block = getattr(model, f'block{ii}')
        wide_block = getattr(wideModel, f'block{i}')
    
        # Setting FFN Weights
        wide_ffn_weight = G_zero_v2(small_block.ffn.weight)
        setattr(wide_block.ffn, 'weight', torch.nn.Parameter((wide_ffn_weight), requires_grad=True).to(device))
        if has_bias == True:
            wide_ffn_bias = G_zero_v2(small_block.ffn.bias)
            setattr(wide_block.ffn, 'bias', torch.nn.Parameter((wide_ffn_bias), requires_grad=True).to(device))

        # Setting Norm Layers
        l_norm_1_weight = G_zero_1d(small_block.l_norm_1.weight)
        l_norm_2_weight = G_zero_1d(small_block.l_norm_2.weight)
        setattr(wide_block.l_norm_1, 'weight', torch.nn.Parameter((l_norm_1_weight), requires_grad=True).to(device))
        setattr(wide_block.l_norm_2, 'weight', torch.nn.Parameter((l_norm_2_weight), requires_grad=True).to(device))
        if has_bias == True:
            l_norm_1_bias = G_zero_1d(small_block.l_norm_1.bias)
            l_norm_2_bias = G_zero_1d(small_block.l_norm_2.bias)
            setattr(wide_block.l_norm_1, 'bias', torch.nn.Parameter((l_norm_1_bias), requires_grad=True).to(device))
            setattr(wide_block.l_norm_2, 'bias', torch.nn.Parameter((l_norm_2_bias), requires_grad=True).to(device))

        # Setting Multi-Head Attention
        wide_mh_lin_weight = G_zero_v2(small_block.multihead.mh_lin.weight)
        setattr(wide_block.multihead.mh_lin, 'weight', torch.nn.Parameter((wide_mh_lin_weight), requires_grad=True))

        for h in range(1, multi_heads+1):
            head = getattr(wide_block.multihead, f'head{h}')
            small_head = getattr(small_block.multihead, f'head{h}')

            new_k = G_zero_v2(small_head.k.weight)
            new_q = G_zero_v2(small_head.q.weight)
            new_v = G_zero_v2(small_head.v.weight)
            setattr(head.k, 'weight', torch.nn.Parameter((new_k), requires_grad=True).to(device))
            setattr(head.q, 'weight', torch.nn.Parameter((new_q), requires_grad=True).to(device))
            setattr(head.v, 'weight', torch.nn.Parameter((new_v), requires_grad=True).to(device))

In [25]:
initWideExtendGZeroV2()

In [26]:
wideModel.train()
optim_b = torch.optim.Adam(params=wideModel.parameters(), lr = 1e-3)

train(wideModel, optim_b, 1000)

2.3463587760925293
Validation: 2.2469334602355957
2.31365704536438
Validation: 2.216831684112549
2.244840621948242
Validation: 2.202496290206909
2.2367405891418457
Validation: 2.1984591484069824
2.221161365509033
Validation: 2.1632261276245117
2.2085494995117188
Validation: 2.1445350646972656
2.1996874809265137
Validation: 2.1144328117370605
2.1736671924591064
Validation: 2.1193418502807617
2.1432104110717773
Validation: 2.089792251586914
2.1327497959136963
Validation: 2.099591016769409


In [27]:
rawWideModel = WideModel().to(device)
optim_b = torch.optim.Adam(params=rawWideModel.parameters(), lr = 1e-3)
train(rawWideModel, optim_b, 1000)

2.533486843109131
Validation: 2.471588134765625
2.388390064239502
Validation: 2.31054425239563
2.29331374168396
Validation: 2.2138044834136963
2.207923173904419
Validation: 2.1886467933654785
2.1867995262145996
Validation: 2.180229663848877
2.1649115085601807
Validation: 2.1219332218170166
2.14483380317688
Validation: 2.130613327026367
2.113114595413208
Validation: 2.1115593910217285
2.0959486961364746
Validation: 2.0923686027526855
2.085519313812256
Validation: 2.087606191635132


Become Deeper + Wider -----------------------------------------------------------------------------------------------------------------

In [28]:
class BigModel(torch.nn.Module):
    def __init__(self, big=True):
        super(BigModel, self).__init__()
        self.embedding = torch.nn.Embedding(vocab_size, new_emb_size if big else emb_size)
        self.pe = PositionalEncoding(new_emb_size if big else emb_size)
        self.block1 = Block(big=big)
        self.block2 = Block(big=big)
        self.block3 = Block(big=big)
        self.block4 = Block(big=big)
        self.block5 = Block(big=big)
        self.block6 = Block(big=big)
        self.block7 = Block(big=big)
        self.block8 = Block(big=big)
        self.f_lin = torch.nn.Linear(new_emb_size if big else emb_size, vocab_size)
        self.drop = torch.nn.Dropout(0.1)
    def forward(self, inp):
        e = self.embedding(inp)
        e = self.pe(e)
        m = self.block1(e)
        m = self.block2(m)
        m = self.block3(m)
        m = self.block4(m)
        m = self.block5(m)
        m = self.block6(m)
        m = self.block7(m)
        m = self.block8(m)
        r = self.f_lin(self.drop(m))
        return r

In [29]:
wideBigModel = BigModel().to(device)

In [30]:
num_big_layers = 8
multi_heads = 2

In [31]:
@torch.no_grad()
def initWideExtendGZero_doubleStack(has_bias=False):
    # Embedding layer
    emb = model.embedding.weight
    wide_emb = G_zero_emb_or_f_lin(emb)
    setattr(wideBigModel.embedding, 'weight', torch.nn.Parameter(wide_emb, requires_grad=True).to(device))
    # f_lin layer weight
    f_lin_weight = model.f_lin.weight
    wide_f_lin_weight = G_zero_emb_or_f_lin(f_lin_weight)
    setattr(wideBigModel.f_lin, 'weight', torch.nn.Parameter(wide_f_lin_weight, requires_grad=True).to(device))
    # f_lin layer bias
    if has_bias == True:
        f_lin_bias = model.f_lin.bias
        wide_f_lin_bias = G_zero_emb_or_f_lin(f_lin_bias)
        setattr(wideBigModel.f_lin, 'bias', torch.nn.Parameter(wide_f_lin_bias, requires_grad=True).to(device))

    
    ii_list = [1,2,3,4,1,2,3,4]
    for i in range(1, num_big_layers+1):
        ii = ii_list[i-1]
        small_block = getattr(model, f'block{ii}')
        wide_block = getattr(wideBigModel, f'block{i}')
    
        # Setting FFN Weights
        wide_ffn_weight = G_zero(small_block.ffn.weight)
        setattr(wide_block.ffn, 'weight', torch.nn.Parameter((wide_ffn_weight), requires_grad=True).to(device))
        if has_bias == True:
            wide_ffn_bias = G_zero(small_block.ffn.bias)
            setattr(wide_block.ffn, 'bias', torch.nn.Parameter((wide_ffn_bias), requires_grad=True).to(device))

        # Setting Norm Layers
        l_norm_1_weight = G_zero_1d(small_block.l_norm_1.weight)
        l_norm_2_weight = G_zero_1d(small_block.l_norm_2.weight)
        setattr(wide_block.l_norm_1, 'weight', torch.nn.Parameter((l_norm_1_weight), requires_grad=True).to(device))
        setattr(wide_block.l_norm_2, 'weight', torch.nn.Parameter((l_norm_2_weight), requires_grad=True).to(device))
        if has_bias == True:
            l_norm_1_bias = G_zero_1d(small_block.l_norm_1.bias)
            l_norm_2_bias = G_zero_1d(small_block.l_norm_2.bias)
            setattr(wide_block.l_norm_1, 'bias', torch.nn.Parameter((l_norm_1_bias), requires_grad=True).to(device))
            setattr(wide_block.l_norm_2, 'bias', torch.nn.Parameter((l_norm_2_bias), requires_grad=True).to(device))

        # Setting Multi-Head Attention
        wide_mh_lin_weight = G_zero(small_block.multihead.mh_lin.weight)
        setattr(wide_block.multihead.mh_lin, 'weight', torch.nn.Parameter((wide_mh_lin_weight), requires_grad=True))

        for h in range(1, multi_heads+1):
            head = getattr(wide_block.multihead, f'head{h}')
            small_head = getattr(small_block.multihead, f'head{h}')

            new_k = G_zero(small_head.k.weight)
            new_q = G_zero(small_head.q.weight)
            new_v = G_zero(small_head.v.weight)
            setattr(head.k, 'weight', torch.nn.Parameter((new_k), requires_grad=True).to(device))
            setattr(head.q, 'weight', torch.nn.Parameter((new_q), requires_grad=True).to(device))
            setattr(head.v, 'weight', torch.nn.Parameter((new_v), requires_grad=True).to(device))

In [32]:
initWideExtendGZero_doubleStack()

In [33]:
optim_b = torch.optim.Adam(params=wideBigModel.parameters(), lr = 1e-3)
train(wideBigModel, optim_b, 1000)

2.366074800491333
Validation: 2.2506210803985596
2.2922167778015137
Validation: 2.2105813026428223
2.255974292755127
Validation: 2.168656349182129
2.194500207901001
Validation: 2.122814178466797
2.181781053543091
Validation: 2.1370317935943604
2.1698203086853027
Validation: 2.1307668685913086
2.1407856941223145
Validation: 2.089634656906128
2.1276755332946777
Validation: 2.097020149230957
2.1221821308135986
Validation: 2.0919740200042725
2.0993220806121826
Validation: 2.08664870262146


In [34]:
wideBigModel = BigModel().to(device)

In [35]:
@torch.no_grad()
def initWideExtendGZero_doubleStac_v2(has_bias=False):
    # Embedding layer
    emb = model.embedding.weight
    wide_emb = G_zero_emb_or_f_lin(emb)
    setattr(wideBigModel.embedding, 'weight', torch.nn.Parameter(wide_emb, requires_grad=True).to(device))
    # f_lin layer weight
    f_lin_weight = model.f_lin.weight
    wide_f_lin_weight = G_zero_emb_or_f_lin(f_lin_weight)
    setattr(wideBigModel.f_lin, 'weight', torch.nn.Parameter(wide_f_lin_weight, requires_grad=True).to(device))
    # f_lin layer bias
    if has_bias == True:
        f_lin_bias = model.f_lin.bias
        wide_f_lin_bias = G_zero_emb_or_f_lin(f_lin_bias)
        setattr(wideBigModel.f_lin, 'bias', torch.nn.Parameter(wide_f_lin_bias, requires_grad=True).to(device))

    
    ii_list = [1,2,3,4,1,2,3,4]
    for i in range(1, num_big_layers+1):
        ii = ii_list[i-1]
        small_block = getattr(model, f'block{ii}')
        wide_block = getattr(wideBigModel, f'block{i}')
    
        # Setting FFN Weights
        wide_ffn_weight = G_zero_v2(small_block.ffn.weight)
        setattr(wide_block.ffn, 'weight', torch.nn.Parameter((wide_ffn_weight), requires_grad=True).to(device))
        if has_bias == True:
            wide_ffn_bias = G_zero_v2(small_block.ffn.bias)
            setattr(wide_block.ffn, 'bias', torch.nn.Parameter((wide_ffn_bias), requires_grad=True).to(device))

        # Setting Norm Layers
        l_norm_1_weight = G_zero_1d(small_block.l_norm_1.weight)
        l_norm_2_weight = G_zero_1d(small_block.l_norm_2.weight)
        setattr(wide_block.l_norm_1, 'weight', torch.nn.Parameter((l_norm_1_weight), requires_grad=True).to(device))
        setattr(wide_block.l_norm_2, 'weight', torch.nn.Parameter((l_norm_2_weight), requires_grad=True).to(device))
        if has_bias == True:
            l_norm_1_bias = G_zero_1d(small_block.l_norm_1.bias)
            l_norm_2_bias = G_zero_1d(small_block.l_norm_2.bias)
            setattr(wide_block.l_norm_1, 'bias', torch.nn.Parameter((l_norm_1_bias), requires_grad=True).to(device))
            setattr(wide_block.l_norm_2, 'bias', torch.nn.Parameter((l_norm_2_bias), requires_grad=True).to(device))

        # Setting Multi-Head Attention
        wide_mh_lin_weight = G_zero_v2(small_block.multihead.mh_lin.weight)
        setattr(wide_block.multihead.mh_lin, 'weight', torch.nn.Parameter((wide_mh_lin_weight), requires_grad=True))

        for h in range(1, multi_heads+1):
            head = getattr(wide_block.multihead, f'head{h}')
            small_head = getattr(small_block.multihead, f'head{h}')

            new_k = G_zero_v2(small_head.k.weight)
            new_q = G_zero_v2(small_head.q.weight)
            new_v = G_zero_v2(small_head.v.weight)
            setattr(head.k, 'weight', torch.nn.Parameter((new_k), requires_grad=True).to(device))
            setattr(head.q, 'weight', torch.nn.Parameter((new_q), requires_grad=True).to(device))
            setattr(head.v, 'weight', torch.nn.Parameter((new_v), requires_grad=True).to(device))

In [36]:
initWideExtendGZero_doubleStac_v2()

In [37]:
optim_b = torch.optim.Adam(params=wideBigModel.parameters(), lr = 1e-3)
train(wideBigModel, optim_b, 1000)

2.3297884464263916
Validation: 2.2664928436279297
2.2681915760040283
Validation: 2.2018697261810303
2.207615852355957
Validation: 2.1389591693878174
2.1661393642425537
Validation: 2.1297335624694824
2.1508235931396484
Validation: 2.1076908111572266
2.130034923553467
Validation: 2.1098811626434326
2.1205713748931885
Validation: 2.0871658325195312
2.1106972694396973
Validation: 2.0801334381103516
2.09464955329895
Validation: 2.0674517154693604
2.058530330657959
Validation: 2.055280923843384


In [39]:
## G Zero Cross stack -----------------------------------------------------------------------------------------------------------------------------------------------
## 1 0
## 0 0

In [40]:
wideBigModel = BigModel().to(device)

In [41]:
@torch.no_grad()
def initWideExtendGZero_crossStack(has_bias=False):
    # Embedding layer
    emb = model.embedding.weight
    wide_emb = G_zero_emb_or_f_lin(emb)
    setattr(wideBigModel.embedding, 'weight', torch.nn.Parameter(wide_emb, requires_grad=True).to(device))
    # f_lin layer weight
    f_lin_weight = model.f_lin.weight
    wide_f_lin_weight = G_zero_emb_or_f_lin(f_lin_weight)
    setattr(wideBigModel.f_lin, 'weight', torch.nn.Parameter(wide_f_lin_weight, requires_grad=True).to(device))
    # f_lin layer bias
    if has_bias == True:
        f_lin_bias = model.f_lin.bias
        wide_f_lin_bias = G_zero_emb_or_f_lin(f_lin_bias)
        setattr(wideBigModel.f_lin, 'bias', torch.nn.Parameter(wide_f_lin_bias, requires_grad=True).to(device))

    
    bi_list = [1,1,2,2,3,3,4,4]
    for i in range(1, num_big_layers+1):
        bi = bi_list[i-1]
        small_block = getattr(model, f'block{bi}')
        wide_block = getattr(wideBigModel, f'block{i}')
    
        # Setting FFN Weights
        wide_ffn_weight = G_zero(small_block.ffn.weight)
        setattr(wide_block.ffn, 'weight', torch.nn.Parameter((wide_ffn_weight), requires_grad=True).to(device))
        if has_bias == True:
            wide_ffn_bias = G_zero(small_block.ffn.bias)
            setattr(wide_block.ffn, 'bias', torch.nn.Parameter((wide_ffn_bias), requires_grad=True).to(device))

        # Setting Norm Layers
        l_norm_1_weight = G_zero_1d(small_block.l_norm_1.weight)
        l_norm_2_weight = G_zero_1d(small_block.l_norm_2.weight)
        setattr(wide_block.l_norm_1, 'weight', torch.nn.Parameter((l_norm_1_weight), requires_grad=True).to(device))
        setattr(wide_block.l_norm_2, 'weight', torch.nn.Parameter((l_norm_2_weight), requires_grad=True).to(device))
        if has_bias == True:
            l_norm_1_bias = G_zero_1d(small_block.l_norm_1.bias)
            l_norm_2_bias = G_zero_1d(small_block.l_norm_2.bias)
            setattr(wide_block.l_norm_1, 'bias', torch.nn.Parameter((l_norm_1_bias), requires_grad=True).to(device))
            setattr(wide_block.l_norm_2, 'bias', torch.nn.Parameter((l_norm_2_bias), requires_grad=True).to(device))

        # Setting Multi-Head Attention
        wide_mh_lin_weight = G_zero(small_block.multihead.mh_lin.weight)
        setattr(wide_block.multihead.mh_lin, 'weight', torch.nn.Parameter((wide_mh_lin_weight), requires_grad=True))

        for h in range(1, multi_heads+1):
            head = getattr(wide_block.multihead, f'head{h}')
            small_head = getattr(small_block.multihead, f'head{h}')

            new_k = G_zero(small_head.k.weight)
            new_q = G_zero(small_head.q.weight)
            new_v = G_zero(small_head.v.weight)
            setattr(head.k, 'weight', torch.nn.Parameter((new_k), requires_grad=True).to(device))
            setattr(head.q, 'weight', torch.nn.Parameter((new_q), requires_grad=True).to(device))
            setattr(head.v, 'weight', torch.nn.Parameter((new_v), requires_grad=True).to(device))

In [42]:
initWideExtendGZero_crossStack()

optim_b = torch.optim.Adam(params=wideBigModel.parameters(), lr = 1e-3)
train(wideBigModel, optim_b, 1000)

2.3411056995391846
Validation: 2.2575109004974365
2.280829668045044
Validation: 2.2327897548675537
2.2365355491638184
Validation: 2.1618454456329346
2.217055320739746
Validation: 2.150111675262451
2.1955485343933105
Validation: 2.126737594604492
2.1572203636169434
Validation: 2.1443989276885986
2.1532928943634033
Validation: 2.117771863937378
2.138089656829834
Validation: 2.1158785820007324
2.1390902996063232
Validation: 2.098881244659424
2.1166787147521973
Validation: 2.091095209121704


In [43]:
## G Zero Cross stack V2-----------------------------------------------------------------------------------------------------------------------------------------------
## 1 0
## 0 1

In [44]:
wideBigModel = BigModel().to(device)

In [45]:
@torch.no_grad()
def initWideExtendGZeroV2_crossStack(has_bias=False):
    # Embedding layer
    emb = model.embedding.weight
    wide_emb = G_zero_emb_or_f_lin(emb)
    setattr(wideBigModel.embedding, 'weight', torch.nn.Parameter(wide_emb, requires_grad=True).to(device))
    # f_lin layer weight
    f_lin_weight = model.f_lin.weight
    wide_f_lin_weight = G_zero_emb_or_f_lin(f_lin_weight)
    setattr(wideBigModel.f_lin, 'weight', torch.nn.Parameter(wide_f_lin_weight, requires_grad=True).to(device))
    # f_lin layer bias
    if has_bias == True:
        f_lin_bias = model.f_lin.bias
        wide_f_lin_bias = G_zero_emb_or_f_lin(f_lin_bias)
        setattr(wideBigModel.f_lin, 'bias', torch.nn.Parameter(wide_f_lin_bias, requires_grad=True).to(device))

    
    bi_list = [1,1,2,2,3,3,4,4]
    for i in range(1, num_big_layers+1):
        bi = bi_list[i-1]
        small_block = getattr(model, f'block{bi}')
        wide_block = getattr(wideBigModel, f'block{i}')
    
        # Setting FFN Weights
        wide_ffn_weight = G_zero_v2(small_block.ffn.weight)
        setattr(wide_block.ffn, 'weight', torch.nn.Parameter((wide_ffn_weight), requires_grad=True).to(device))
        if has_bias == True:
            wide_ffn_bias = G_zero_v2(small_block.ffn.bias)
            setattr(wide_block.ffn, 'bias', torch.nn.Parameter((wide_ffn_bias), requires_grad=True).to(device))

        # Setting Norm Layers
        l_norm_1_weight = G_zero_1d(small_block.l_norm_1.weight)
        l_norm_2_weight = G_zero_1d(small_block.l_norm_2.weight)
        setattr(wide_block.l_norm_1, 'weight', torch.nn.Parameter((l_norm_1_weight), requires_grad=True).to(device))
        setattr(wide_block.l_norm_2, 'weight', torch.nn.Parameter((l_norm_2_weight), requires_grad=True).to(device))
        if has_bias == True:
            l_norm_1_bias = G_zero_1d(small_block.l_norm_1.bias)
            l_norm_2_bias = G_zero_1d(small_block.l_norm_2.bias)
            setattr(wide_block.l_norm_1, 'bias', torch.nn.Parameter((l_norm_1_bias), requires_grad=True).to(device))
            setattr(wide_block.l_norm_2, 'bias', torch.nn.Parameter((l_norm_2_bias), requires_grad=True).to(device))

        # Setting Multi-Head Attention
        wide_mh_lin_weight = G_zero_v2(small_block.multihead.mh_lin.weight)
        setattr(wide_block.multihead.mh_lin, 'weight', torch.nn.Parameter((wide_mh_lin_weight), requires_grad=True))

        for h in range(1, multi_heads+1):
            head = getattr(wide_block.multihead, f'head{h}')
            small_head = getattr(small_block.multihead, f'head{h}')

            new_k = G_zero_v2(small_head.k.weight)
            new_q = G_zero_v2(small_head.q.weight)
            new_v = G_zero_v2(small_head.v.weight)
            setattr(head.k, 'weight', torch.nn.Parameter((new_k), requires_grad=True).to(device))
            setattr(head.q, 'weight', torch.nn.Parameter((new_q), requires_grad=True).to(device))
            setattr(head.v, 'weight', torch.nn.Parameter((new_v), requires_grad=True).to(device))

In [46]:
initWideExtendGZeroV2_crossStack()

optim_b = torch.optim.Adam(params=wideBigModel.parameters(), lr = 1e-3)
train(wideBigModel, optim_b, 1000)

2.341594934463501
Validation: 2.2380168437957764
2.264225721359253
Validation: 2.213564157485962
2.228527307510376
Validation: 2.1478302478790283
2.1843695640563965
Validation: 2.108391046524048
2.1760964393615723
Validation: 2.1332359313964844
2.1275227069854736
Validation: 2.1187002658843994
2.1264758110046387
Validation: 2.0834903717041016
2.1154003143310547
Validation: 2.0648632049560547
2.0729153156280518
Validation: 2.0608694553375244
2.034621000289917
Validation: 2.0541322231292725


In [47]:
wideBigModel_scratch = BigModel().to(device)
optim_b = torch.optim.Adam(params=wideBigModel_scratch.parameters(), lr = 1e-3)
train(wideBigModel_scratch, optim_b, 1000)

2.4599757194519043
Validation: 2.371760845184326
2.2479116916656494
Validation: 2.2242133617401123
2.1852540969848633
Validation: 2.1416192054748535
2.120945930480957
Validation: 2.0941712856292725
2.0811314582824707
Validation: 2.0433788299560547
2.0192532539367676
Validation: 2.03027606010437
1.9738414287567139
Validation: 2.016104221343994
1.9750126600265503
Validation: 2.010023355484009
1.955073356628418
Validation: 2.0032286643981934
1.945717453956604
Validation: 1.9732179641723633
