In [None]:
# Colab setup: download the Penn TreeBank splits.
# `-nc` (no-clobber) skips files that already exist, so re-running this cell does not
# leave duplicate `ptb.*.txt.1` copies next to the originals.
!wget -nc -P dataset/PennTreeBank https://raw.githubusercontent.com/BrownFortress/NLU-2024-Labs/main/labs/dataset/PennTreeBank/ptb.test.txt
!wget -nc -P dataset/PennTreeBank https://raw.githubusercontent.com/BrownFortress/NLU-2024-Labs/main/labs/dataset/PennTreeBank/ptb.valid.txt
!wget -nc -P dataset/PennTreeBank https://raw.githubusercontent.com/BrownFortress/NLU-2024-Labs/main/labs/dataset/PennTreeBank/ptb.train.txt

# Imports
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from functools import partial
import torch.utils.data as data
from torch.utils.data import DataLoader
import math
import numpy as np

--2025-11-03 20:33:55--  https://raw.githubusercontent.com/BrownFortress/NLU-2024-Labs/main/labs/dataset/PennTreeBank/ptb.test.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 449945 (439K) [text/plain]
Saving to: ‘dataset/PennTreeBank/ptb.test.txt’


2025-11-03 20:33:55 (17.7 MB/s) - ‘dataset/PennTreeBank/ptb.test.txt’ saved [449945/449945]

--2025-11-03 20:33:55--  https://raw.githubusercontent.com/BrownFortress/NLU-2024-Labs/main/labs/dataset/PennTreeBank/ptb.valid.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 399782 (

In [None]:
#
class LM_RNN(nn.Module):
    """Vanilla (Elman) RNN language model.

    Token ids are embedded, run through a unidirectional batch-first ``nn.RNN`` and
    every hidden state is projected to vocabulary logits.

    Args:
        emb_size: embedding dimension.
        hidden_size: RNN hidden-state dimension.
        output_size: vocabulary size (rows of the embedding table and number of logits).
        pad_index: token id passed as ``padding_idx`` to the embedding.
        out_dropout, emb_dropout: accepted for signature compatibility with the other
            models in this notebook but NOT used by this class.
        n_layers: number of stacked RNN layers.
    """

    def __init__(self, emb_size, hidden_size, output_size, pad_index=0, out_dropout=0.1,
                 emb_dropout=0.1, n_layers=1):
        super().__init__()
        # Token ids -> dense vectors.
        self.embedding = nn.Embedding(output_size, emb_size, padding_idx=pad_index)
        # https://pytorch.org/docs/stable/generated/torch.nn.RNN.html
        self.rnn = nn.RNN(emb_size, hidden_size, n_layers,
                          bidirectional=False, batch_first=True)
        self.pad_token = pad_index
        # Hidden state -> vocabulary logits.
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, input_sequence):
        """Map (batch, seq_len) token ids to (batch, vocab, seq_len) logits."""
        hidden_states = self.rnn(self.embedding(input_sequence))[0]
        logits = self.output(hidden_states)
        # CrossEntropyLoss expects the class dimension in position 1.
        return logits.permute(0, 2, 1)

class LM_LSTM(nn.Module):
    """LSTM language model, by default with tied embedding / output-projection weights.

    Args:
        emb_size: embedding dimension.
        hidden_size: LSTM hidden-state dimension.
        output_size: vocabulary size.
        pad_index: token id passed as ``padding_idx`` to the embedding.
        out_dropout, emb_dropout: accepted for signature compatibility, not used.
        n_layers: number of stacked LSTM layers.
        tie_weights: if True (default), the output Linear reuses the embedding matrix.

    Raises:
        ValueError: if ``tie_weights`` is True and ``emb_size != hidden_size``. Tying
            installs the (output_size, emb_size) embedding matrix as the weight of a
            Linear(hidden_size, output_size), which only works when the sizes agree.
    """

    def __init__(self, emb_size, hidden_size, output_size, pad_index=0, out_dropout=0.1,
                 emb_dropout=0.1, n_layers=1, tie_weights=True):
        super(LM_LSTM, self).__init__()
        if tie_weights and emb_size != hidden_size:
            # Fail at construction time instead of with a matmul shape error in forward().
            raise ValueError(
                f"Weight tying requires emb_size == hidden_size "
                f"(got emb_size={emb_size}, hidden_size={hidden_size})")
        self.embedding = nn.Embedding(output_size, emb_size, padding_idx=pad_index)
        self.rnn = nn.LSTM(emb_size, hidden_size, n_layers, bidirectional=False, batch_first=True)  # LSTM instead of RNN
        self.pad_token = pad_index
        self.output = nn.Linear(hidden_size, output_size)
        if tie_weights:
            self.output.weight = self.embedding.weight  # weight tying

    def forward(self, input_sequence):
        """Map (batch, seq_len) token ids to (batch, vocab, seq_len) logits."""
        emb = self.embedding(input_sequence)
        rnn_out, _ = self.rnn(emb)  # returns (output, (h_n, c_n))
        # CrossEntropyLoss expects the class dimension in position 1.
        output = self.output(rnn_out).permute(0, 2, 1)
        return output

class LM_LSTM_Dropout(nn.Module):
    """LSTM language model with standard (per-element) dropout.

    Dropout is applied to the embeddings before the LSTM and to the LSTM outputs
    before the vocabulary projection. No weight tying.

    Args:
        emb_size: embedding dimension.
        hidden_size: LSTM hidden-state dimension.
        output_size: vocabulary size.
        pad_index: token id passed as ``padding_idx`` to the embedding.
        emb_dropout: drop probability on the embeddings.
        hid_dropout: drop probability on the LSTM outputs.
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, emb_size, hidden_size, output_size, pad_index=0,
                 emb_dropout=0.1, hid_dropout=0.1, n_layers=1):
        super().__init__()
        # Sub-modules are registered in the same order as before (affects init RNG and
        # state_dict ordering).
        self.embedding = nn.Embedding(output_size, emb_size, padding_idx=pad_index)
        self.emb_dropout = nn.Dropout(emb_dropout)
        self.rnn = nn.LSTM(emb_size, hidden_size, n_layers,
                           batch_first=True, bidirectional=False)
        self.out_dropout = nn.Dropout(hid_dropout)
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, input_sequence):
        """Map (batch, seq_len) token ids to (batch, vocab, seq_len) logits."""
        embedded = self.emb_dropout(self.embedding(input_sequence))
        hidden_states, _ = self.rnn(embedded)
        logits = self.output(self.out_dropout(hidden_states))
        return logits.permute(0, 2, 1)

class LockedDropout(nn.Module):
    """Variational ("locked") dropout: one dropout mask per sequence, reused at every time step.

    Args:
        batch_first: True (default) if inputs are (batch, seq_len, features), which is
            the layout produced by the batch_first LSTMs in this notebook; False for
            (seq_len, batch, features) inputs.
    """

    def __init__(self, batch_first=True):
        super(LockedDropout, self).__init__()
        self.batch_first = batch_first

    def forward(self, x, dropout=0.5):
        """Apply locked dropout with drop probability ``dropout``.

        Returns ``x`` unchanged in eval mode or when ``dropout == 0``. Otherwise kept
        units are rescaled by 1 / (1 - dropout) (inverted dropout).
        """
        if not self.training or dropout == 0:
            return x
        # The mask has size 1 along the TIME axis, so the same units are dropped at every
        # step of a sequence while different sequences in the batch get different masks.
        # For batch_first input the time axis is dim 1, not dim 0.
        if self.batch_first:
            mask_shape = (x.size(0), 1, x.size(2))
        else:
            mask_shape = (1, x.size(1), x.size(2))
        mask = x.new_empty(mask_shape).bernoulli_(1 - dropout)
        mask = mask.div_(1 - dropout)
        return x * mask.expand_as(x)

class LM_LSTM_VDO(nn.Module):
    """LSTM language model with variational (locked) dropout and weight tying.

    Locked dropout is applied to the embeddings and to the LSTM outputs. The output
    projection has no bias and shares its weight matrix with the embedding.

    Args:
        emb_size: embedding dimension; must equal ``hidden_size`` (weight tying).
        hidden_size: LSTM hidden-state dimension.
        output_size: vocabulary size.
        pad_index: token id passed as ``padding_idx`` to the embedding (default 0,
            as in the other models of this notebook).
        emb_dropout: locked-dropout probability on the embeddings.
        hid_dropout: locked-dropout probability on the LSTM outputs.
        n_layers: number of stacked LSTM layers.

    Raises:
        ValueError: if ``emb_size != hidden_size``.
    """

    def __init__(self, emb_size, hidden_size, output_size, pad_index=0,
                 emb_dropout=0.1, hid_dropout=0.1, n_layers=1):
        super().__init__()
        if emb_size != hidden_size:
            # Tying installs the (output_size, emb_size) embedding matrix as the weight of
            # Linear(hidden_size, output_size); fail here rather than in forward().
            raise ValueError(
                f"Weight tying requires emb_size == hidden_size "
                f"(got emb_size={emb_size}, hidden_size={hidden_size})")
        self.embedding = nn.Embedding(output_size, emb_size, padding_idx=pad_index)
        self.lockdrop = LockedDropout()  # variational dropout (stateless, shared by both sites)
        self.emb_dropout = emb_dropout
        self.hid_dropout = hid_dropout

        self.rnn = nn.LSTM(emb_size, hidden_size, n_layers,
                           batch_first=True, bidirectional=False)

        self.output = nn.Linear(hidden_size, output_size, bias=False)
        self.output.weight = self.embedding.weight  # weight tying

    def forward(self, x):
        """Map (batch, seq_len) token ids to (batch, vocab, seq_len) logits."""
        emb = self.embedding(x)
        emb = self.lockdrop(emb, self.emb_dropout)  # variational dropout on embeddings

        rnn_out, _ = self.rnn(emb)
        rnn_out = self.lockdrop(rnn_out, self.hid_dropout)  # variational dropout on hidden states

        # CrossEntropyLoss expects the class dimension in position 1.
        out = self.output(rnn_out).permute(0, 2, 1)
        return out

# Loading the corpus

def read_file(path, eos_token="<eos>"):
    """Read a text corpus, one sentence per line.

    Each line is stripped of surrounding whitespace and gets ``eos_token`` appended
    after a single space.

    Returns:
        list[str]: one entry per line of the file.
    """
    with open(path, "r") as f:
        return [line.strip() + " " + eos_token for line in f]

# Vocab with tokens to ids
def get_vocab(corpus, special_tokens=[]):
    """Build a token -> id mapping.

    Special tokens take ids 0..len(special_tokens)-1 in order. Every other
    whitespace-separated token of ``corpus`` then gets the next free id, in order of
    first appearance.
    """
    word2id = {tok: idx for idx, tok in enumerate(special_tokens)}
    next_id = len(special_tokens)
    for line in corpus:
        for token in line.split():
            if token in word2id:
                continue
            word2id[token] = next_id
            next_id += 1
    return word2id

# This class computes and stores our vocab
# Word to ids and ids to word
class Lang():
    """Vocabulary in both directions: ``word2id`` (token -> id) and ``id2word`` (id -> token)."""

    def __init__(self, corpus, special_tokens=[]):
        self.word2id = self.get_vocab(corpus, special_tokens)
        self.id2word = dict((idx, tok) for tok, idx in self.word2id.items())

    def get_vocab(self, corpus, special_tokens=[]):
        """Assign ids: special tokens first (in order), then corpus tokens by first appearance."""
        vocab = {tok: idx for idx, tok in enumerate(special_tokens)}
        next_id = len(special_tokens)
        for sentence in corpus:
            for word in sentence.split():
                if word in vocab:
                    continue
                vocab[word] = next_id
                next_id += 1
        return vocab

class PennTreeBank (data.Dataset):
    """Next-token-prediction dataset over whitespace-tokenised sentences.

    For each sentence, ``source`` is every token but the last and ``target`` is every
    token but the first, so ``target[i]`` is the token that follows ``source[i]``.
    """

    def __init__(self, corpus, lang):
        self.source = []
        self.target = []
        for sentence in corpus:
            tokens = sentence.split()
            self.source.append(tokens[:-1])  # first .. second-to-last token
            self.target.append(tokens[1:])   # second .. last token

        self.source_ids = self.mapping_seq(self.source, lang)
        self.target_ids = self.mapping_seq(self.target, lang)

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx):
        return {
            'source': torch.LongTensor(self.source_ids[idx]),
            'target': torch.LongTensor(self.target_ids[idx]),
        }

    # Auxiliary methods

    def mapping_seq(self, data, lang):
        """Convert token lists to id lists through ``lang.word2id``.

        On the first out-of-vocabulary token of a sequence, a warning is printed and
        that sequence is truncated there (ids collected so far are kept).
        """
        res = []
        for seq in data:
            ids = []
            for token in seq:
                if token not in lang.word2id:
                    # PennTreeBank should have no OOV, but "Trust is good, control is better!"
                    print('OOV found!')
                    print('You have to deal with that')
                    break
                ids.append(lang.word2id[token])
            res.append(ids)
        return res

def collate_fn(data, pad_token, device=None):
    """Collate PennTreeBank samples into a padded batch sorted by source length.

    Args:
        data: list of dicts with 1-D LongTensor ``'source'`` / ``'target'`` entries, as
            returned by ``PennTreeBank.__getitem__``. The list itself is not modified.
        pad_token: id written into positions past each sequence's end.
        device: where to place the padded tensors. ``None`` (default) uses the
            notebook-level ``DEVICE`` constant, looked up when the function is called.

    Returns:
        dict with ``'source'`` and ``'target'`` as (batch, max_len) LongTensors on
        ``device`` (longest source first), and ``'number_tokens'``, the total number of
        target tokens before padding.
    """
    if device is None:
        device = DEVICE

    def merge(sequences):
        '''
        merge from batch * sent_len to batch * max_len
        '''
        lengths = [len(seq) for seq in sequences]
        # Keep at least one column so a batch of empty sequences still gives a 2-D tensor.
        max_len = max(max(lengths), 1)
        padded_seqs = torch.full((len(sequences), max_len), pad_token, dtype=torch.long)
        for i, seq in enumerate(sequences):
            padded_seqs[i, :lengths[i]] = seq  # copy each sequence into its row
        return padded_seqs, lengths

    # Sort by source length (longest first) without mutating the caller's list.
    batch = sorted(data, key=lambda x: len(x["source"]), reverse=True)
    new_item = {key: [d[key] for d in batch] for key in batch[0].keys()}

    source, _ = merge(new_item["source"])
    target, lengths = merge(new_item["target"])

    new_item["source"] = source.to(device)
    new_item["target"] = target.to(device)
    new_item["number_tokens"] = sum(lengths)
    return new_item

def train_loop(data, optimizer, criterion, model, clip=5):
    """Run one training epoch and return the token-weighted mean training loss.

    Each batch's loss (``criterion`` output) is weighted by its ``'number_tokens'``
    before averaging. Gradients are clipped to a global norm of ``clip`` before every
    optimizer step.
    """
    model.train()
    weighted_loss_sum = 0
    token_total = 0

    for batch in data:
        optimizer.zero_grad()
        logits = model(batch['source'])
        batch_loss = criterion(logits, batch['target'])
        weighted_loss_sum += batch_loss.item() * batch["number_tokens"]
        token_total += batch["number_tokens"]
        batch_loss.backward()
        # Guard against exploding gradients, a classic RNN failure mode.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

    return weighted_loss_sum / token_total

def eval_loop(data, eval_criterion, model):
    """Evaluate ``model`` on ``data``; return ``(perplexity, per_token_loss)``.

    ``eval_criterion`` should use ``reduction='sum'`` (as ``criterion_eval`` does), so
    the summed batch losses divided by the total ``'number_tokens'`` give the mean
    per-token loss. Perplexity is ``exp`` of that mean.
    """
    model.eval()
    summed_loss = 0
    token_count = 0
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for batch in data:
            logits = model(batch['source'])
            summed_loss += eval_criterion(logits, batch['target']).item()
            token_count += batch["number_tokens"]

    mean_loss = summed_loss / token_count
    return math.exp(mean_loss), mean_loss

def init_weights(mat):
    """Re-initialise the recurrent and linear layers found anywhere inside ``mat``.

    - nn.LSTM / nn.GRU / nn.RNN: the stacked input-hidden and hidden-hidden matrices are
      split per gate (4 for LSTM, 3 for GRU, 1 for RNN). Each gate block gets Xavier-uniform
      (input-hidden) or orthogonal (hidden-hidden) init. Biases are zeroed.
    - nn.Linear: weights ~ U(-0.01, 0.01), bias (if any) set to 0.01.

    Other module types (e.g. nn.Embedding) are left untouched. Note that with weight
    tying, re-initialising the Linear also re-initialises the shared embedding matrix.
    """
    # PyTorch stacks the per-gate matrices along dim 0; the gate count depends on the cell type.
    gates_per_type = {nn.LSTM: 4, nn.GRU: 3, nn.RNN: 1}
    for m in mat.modules():
        n_gates = gates_per_type.get(type(m))
        if n_gates is not None:
            for name, param in m.named_parameters():
                if 'weight_ih' in name:
                    for gate_block in param.chunk(n_gates, dim=0):
                        torch.nn.init.xavier_uniform_(gate_block)
                elif 'weight_hh' in name:
                    for gate_block in param.chunk(n_gates, dim=0):
                        torch.nn.init.orthogonal_(gate_block)
                elif 'bias' in name:
                    torch.nn.init.zeros_(param)
        elif type(m) is nn.Linear:
            torch.nn.init.uniform_(m.weight, -0.01, 0.01)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

def load_asgd_averaged_weights(model, optimizer):
    """Overwrite each parameter with ASGD's running average (``optimizer.state[p]['ax']``).

    Parameters without an ``'ax'`` entry (or with ``'ax'`` set to None) are left as
    they are. ``model`` is not used; the parameters are reached through the optimizer.
    """
    all_params = (p for group in optimizer.param_groups for p in group['params'])
    for param in all_params:
        averaged = optimizer.state[param].get('ax')
        if averaged is None:
            continue
        param.data.copy_(averaged)



# 'TUNABLE' PARAMETERS
# Fall back to CPU so the notebook still runs (slowly) without a GPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
SEED = 42             # seeds torch so init / dropout / shuffling are reproducible
TBS = 32              # training batch size
LR = 3
HID_DIM = 700
EMB_SIZE = 700        # must equal HID_DIM: LM_LSTM / LM_LSTM_VDO tie embedding and output weights
n_epochs = 100
EMB_DO = 0.4
OUT_DO = 0.7
PATIENCE = 3          # epochs without dev improvement before triggering averaging / stopping
ASGD_LR_DECAY = 0.33  # LR multiplier applied once NT-ASGD averaging starts

torch.manual_seed(SEED)

train_raw = read_file("dataset/PennTreeBank/ptb.train.txt")
dev_raw = read_file("dataset/PennTreeBank/ptb.valid.txt")
test_raw = read_file("dataset/PennTreeBank/ptb.test.txt")

clip = 5 # Clip the gradient
vocab = get_vocab(train_raw, ["<pad>", "<eos>"])  # same mapping as lang.word2id
lang = Lang(train_raw, ["<pad>", "<eos>"])
vocab_len = len(lang.word2id)
#MODEL= LM_RNN(EMB_SIZE, HID_DIM, vocab_len, pad_index=lang.word2id["<pad>"]).to(DEVICE) # RNN MODEL
#MODEL = LM_LSTM(EMB_SIZE, HID_DIM, vocab_len, pad_index=lang.word2id["<pad>"]).to(DEVICE) # LSTM MODEL (with or without weight tying)
#MODEL = LM_LSTM_Dropout(EMB_SIZE, HID_DIM, vocab_len,pad_index=lang.word2id["<pad>"], emb_dropout=EMB_DO, hid_dropout=OUT_DO).to(DEVICE) # LSTM with DropOut MODEL
MODEL = LM_LSTM_VDO(EMB_SIZE, HID_DIM, vocab_len,pad_index=lang.word2id["<pad>"], emb_dropout=EMB_DO, hid_dropout=OUT_DO).to(DEVICE) # LSTM with Variational DropOut MODEL
model = MODEL
model.apply(init_weights)

#optimizer = optim.SGD(model.parameters(), lr=LR)  # SDG OPTIMIZER
#optimizer = torch.optim.AdamW(MODEL.parameters(), lr=LR, weight_decay=1e-2) # ADAMW OPTIMIZER
optimizer = torch.optim.ASGD(
    model.parameters(),
    lr=LR,
    t0=0,        # ASGD's built-in average (state['ax']) starts at step 0; it is NOT used below,
                 # NT-ASGD averaging is done manually through `avg_params`
    lambd=0.0,   # no decay term, so the effective step size stays equal to the group lr
    alpha=0.75   # only changes the step size when lambd > 0
)

criterion_train = nn.CrossEntropyLoss(ignore_index=lang.word2id["<pad>"])
criterion_eval = nn.CrossEntropyLoss(ignore_index=lang.word2id["<pad>"], reduction='sum')

train_dataset = PennTreeBank(train_raw, lang)
dev_dataset = PennTreeBank(dev_raw, lang)
test_dataset = PennTreeBank(test_raw, lang)

# Main loop
train_loader = DataLoader(train_dataset, batch_size=TBS, collate_fn=partial(collate_fn, pad_token=lang.word2id["<pad>"]),  shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=128, collate_fn=partial(collate_fn, pad_token=lang.word2id["<pad>"]))
test_loader = DataLoader(test_dataset, batch_size=128, collate_fn=partial(collate_fn, pad_token=lang.word2id["<pad>"]))
patience = PATIENCE
losses_train = []
losses_dev = []
sampled_epochs = []
best_ppl = math.inf
best_model = None
triggered = False  # flag for NT-ASGD averaging

# Running average of the weights collected after NT-ASGD is triggered. It is kept outside
# optimizer.state (to avoid PyTorch foreach grouping issues) and on CPU (to save GPU memory).
avg_params = {}   # maps id(param) -> averaged CPU tensor
n_averaged = 0    # number of epoch-end weight snapshots folded into avg_params

pbar = tqdm(range(1, n_epochs + 1))  # +1 so that exactly n_epochs epochs are run
for epoch in pbar:
    loss = train_loop(train_loader, optimizer, criterion_train, model, clip)

    sampled_epochs.append(epoch)
    losses_train.append(float(loss))
    ppl_dev, loss_dev = eval_loop(dev_loader, criterion_eval, model)
    losses_dev.append(float(loss_dev))
    pbar.set_description("PPL: %f" % ppl_dev)

    if ppl_dev < best_ppl:  # improvement
        best_ppl = ppl_dev
        best_model = copy.deepcopy(model).to('cpu')
        patience = PATIENCE
    else:
        patience -= 1

    if triggered:
        # Fold this epoch's weights into the running mean: avg += (w - avg) / k.
        n_averaged += 1
        for group in optimizer.param_groups:
            for p in group['params']:
                avg = avg_params[id(p)]
                avg += (p.detach().cpu() - avg) / n_averaged
        if patience <= 0:  # early stop after averaging
            break
    elif patience <= 0:
        # --- NT-ASGD trigger (SAFE: keep averages outside optimizer.state) ---
        print(">>> Triggering NT-ASGD averaging (safe mode)")
        for group in optimizer.param_groups:
            for p in group['params']:
                avg_params[id(p)] = p.detach().cpu().clone()  # first sample of the average
            group['lr'] *= ASGD_LR_DECAY  # reduce LR once averaging starts
        n_averaged = 1
        triggered = True
        patience = PATIENCE  # reset patience to allow averaging to continue

if triggered:
    print(">>> Loading NT-ASGD averaged weights into model (from avg_params)")
    # Copy averaged weights (stored on CPU) back into model parameters (on DEVICE)
    for group in optimizer.param_groups:
        for p in group['params']:
            key = id(p)
            if key in avg_params:
                p.data.copy_(avg_params[key].to(p.device))
    # Keep the averaged weights only if they beat the best single checkpoint on dev.
    ppl_dev_avg, _ = eval_loop(dev_loader, criterion_eval, model)
    print('Dev ppl with averaged weights: ', ppl_dev_avg)
    if ppl_dev_avg < best_ppl:
        best_ppl = ppl_dev_avg
        best_model = copy.deepcopy(model).to('cpu')

if best_model is None:  # e.g. every dev perplexity was NaN
    best_model = copy.deepcopy(model).to('cpu')

best_model.to(DEVICE)
final_ppl,  _ = eval_loop(test_loader, criterion_eval, best_model)
print('Test ppl: ', final_ppl)

  0%|          | 0/99 [00:00<?, ?it/s]

>>> Triggering NT-ASGD averaging (safe mode)
>>> Loading NT-ASGD averaged weights into model (from avg_params)
Test ppl:  89.84049225798147


In [None]:
# To save the model.
# Save `best_model`: it is the model whose test perplexity is printed by the training cell.
# `model` holds the last-epoch / averaged weights, which can differ from the selected model.
path = '/model_LSTM_NTAvSDG.pt'
torch.save(best_model.state_dict(), path)
# To load the model you first need to re-create it with the same hyper-parameters, e.g.
# model = LM_LSTM_VDO(EMB_SIZE, HID_DIM, vocab_len, pad_index=lang.word2id["<pad>"],
#                     emb_dropout=EMB_DO, hid_dropout=OUT_DO).to(DEVICE)
# Then you load it
# model.load_state_dict(torch.load(path, map_location=DEVICE))

In [None]:
# --- Rebuild the model architecture (hyper-parameters must match the saved checkpoint) ---
checkpoint_path = "/model_LSTM_NTAvSDG.pt"
vdo_kwargs = dict(
    pad_index=lang.word2id["<pad>"],
    emb_dropout=EMB_DO,
    hid_dropout=OUT_DO,
)
loaded_model = LM_LSTM_VDO(EMB_SIZE, HID_DIM, vocab_len, **vdo_kwargs).to(DEVICE)

# --- Load weights (map_location places the tensors on the current DEVICE) ---
checkpoint_state = torch.load(checkpoint_path, map_location=DEVICE)
loaded_model.load_state_dict(checkpoint_state)

# --- Evaluate on test set ---
loaded_model.eval()
final_ppl, final_loss = eval_loop(test_loader, criterion_eval, loaded_model)
print("Loaded model test PPL:", final_ppl)


Loaded model test PPL: 92.1442226468789
