### Translation Transformer Model:
- Following <a href="https://nlp.seas.harvard.edu/2018/04/03/attention.html" target="_blank">The Annotated Transformer</a>


In [None]:
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax
import copy
import math
tttt. . fdfd.      ssss sssssssss
import os
import pandas as pd
import time
import numpy as np

PRINT = True

### Data Preparation

### Model Architecture
- Encoder-decoder structure works the best for translation.  

- The encoder maps an 
    - input sequence of symbol representations $(x_1, ..., x_n)$ to 
    - a sequence of continuous representations $z = (z_1, ..., z_n)$.
- Given z, the decoder
    - generates an output sequence $(y_1, ..., y_m)$ of symbols one element a time.
    - at each step the model is auto-regressive, consuming the previously generated symbols as additional input when generating the next.


In [None]:
class EncoderDecoder(nn.Module):
    '''A standard Encoder-Decoder architecture. Base for this and many other models.'''
     
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        # python 3
        # super().__init__()
        # python 2
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
    
    # what are these masks for why do we need two masks?
    def forward(self, src, tgt, src_mask, tgt_mask):
        # src_mask (on encoder's input) is to make sure the inputs have the same length.
        # tgt_mask (on decoder's input) is to prevent the model to cheat by seeing the results (next word it's supposed to predict), zigzag zero padding.
        # NOTE: here it's calling the self-defined encode and decode functions, and switch the order of the parameters.
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)
    
    def encode(self, src, src_mask):
        if PRINT:
            print(f'in EncoderDecoder before embedding(Encoder): src: {src.shape}; src_mask: {src_mask.shape}')
        return self.encoder(self.src_embed(src), src_mask)
    
    def decode(self, memory, src_mask, tgt, tgt_mask):
        if PRINT:
            print(f'in EncoderDecoder before embedding(Decoder): tgt: {tgt.shape}; tgt_mask: {tgt_mask.shape}')
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

In [None]:
class Generator(nn.Module):
    ''' Standard linear + softmax generation step '''
    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)
    
    def forward(self, x):
        return log_softmax(self.proj(x), dim=-1)

In [None]:
def clones(module, N):
    # Produce N identical layers
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def attention(query, key, value, mask=None, dropout=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = torch.softmax(scores, dim=-1)

    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


def subsequent_mask(size):
    attn_shape = (1, size, size)
    # diagonal=0: includes the main diagonal.
	# diagonal=1: starts one step above the main diagonal.
	# diagonal=-1: starts one step below the main diagonal.
    # torch.triu(..., diagonal=1) returns the upper triangular part
    # [[[0, 1, 1, 1, 1],
    #   [0, 0, 1, 1, 1],
    #   [0, 0, 0, 1, 1],
    #   [0, 0, 0, 0, 1],
    #   [0, 0, 0, 0, 0]]]
    # .type(torch.uint8) converts it to 0 (for allowed) and 1 (for masked)
    # Later when apply mask:
    # scores = scores.masked_fill(mask == 0, 0)   # replaces elements of the tensor where the condition is True with the given value 0.
    # scores = scores.masked_fill(mask == 1, -inf)  # replaces elements of the tensor where the condition is True with the given value -inf.
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)

    # Return a boolean mask tensor where each element is True if the corresponding element in subsequent_mask is 0, and False if it’s 1.
    # tensor([[[0, 1, 1],
    #          [0, 0, 1],
    #          [0, 0, 0]]], dtype=torch.uint8)
    # would return
    # tensor([[[ True, False, False],
    #          [ True,  True, False],
    #          [ True,  True,  True]]])
    return subsequent_mask == 0

In [None]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout = 0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # NOTE: why 4 here?
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        if PRINT:
            print('In MHA:')
            for i, layer in enumerate(self.linears):
                print(f"Layer {i}: in_features = {layer.in_features}, out_features = {layer.out_features}")
            
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]

        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        x = (x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k))

        del query
        del key
        del value
        mha_return_val = self.linears[-1](x)
        if PRINT:
            print('MHA return val:', mha_return_val.shape)
        return mha_return_val

In [173]:
class LayerNorm(nn.Module):
    '''Construct a layernorm module (See citation for details).'''
    # TODO: what is the purpose of eps?
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps
    
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

In [174]:
class SublayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # NOTE: the sublayer is a lambda wrapper on the multi-head attention from encoder & decoder
        return x + self.dropout(sublayer(self.norm(x)))

In [175]:
class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn:MultiHeadedAttention, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

In [None]:
class Encoder(nn.Module):
    def __init__(self, layer:EncoderLayer, N):
        super(Encoder, self).__init__()
        # The encoder is composed of a stack of N identical layers.
        # repeat the entire Encoder layer (multi-head attention, feedforward, layer_norm...) N times.
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x, mask):
        if PRINT:
            print(f'in Encoder after embedding: src:{x.shape}; src_mask: {mask.shape}')
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


In [177]:
class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn:MultiHeadedAttention, src_attn:MultiHeadedAttention, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)
    
    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)

In [178]:
class Decoder(nn.Module):
    def __init__(self, layer:DecoderLayer, N):
        super(Decoder, self).__init__()
        # repeat the entire Decoder layer (2 multi-head attention, feedforward, layer_norm...) N times.
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        if PRINT:
            print(f'in Decoder after embedding: tgt: {x.shape}; encoder_output:{memory.shape}; encoder_output_mask: {src_mask.shape}; tgt_mask:{tgt_mask.shape}')
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return self.norm(x)

In [179]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        if PRINT:
            print(f'in FF: input shape: {x.shape}')
        ff_val = self.w_2(self.dropout(self.w_1(x).relu()))
        if PRINT:
            print(f'in FF: output shape: {x.shape}')
        return ff_val

### What does Embedding layer do:

- This embedding layer maps each token index to a d_model-dimensional vector.
- self.lut(x) Looks up embeddings for the token indices x. Then, scales the embeddings by a constant factor $$\sqrt{d_{model}}$$  
- This is a normalization trick to help with training stability. The dot-product    attention has a scaling factor $$\frac{1}{\sqrt{d_k}}$$ so scaling the input embeddings helps balance the magnitudes.  

- Without this scaling, the softmax in attention could produce very small gradients, especially at the beginning of training.

In [None]:
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        # This is the look up table, retrieving vectors using token IDs.
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        # (batch_size, seq_len) -> (batch_size, seq_len, d_model)
        return self.lut(x) * math.sqrt(self.d_model)

In [181]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

In [182]:
def make_model(src_vocab, tgt_vocab, N=6, d_model= 512, d_ff=2048, h=8, dropout=0.1):
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab)
    )
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    
    return model

### Testing the model

In [183]:
def inference_test():
    test_model = make_model(11, 11, 2)
    test_model.eval()
    src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    src_mask = torch.ones(1, 1, 10)

    memory = test_model.encode(src, src_mask)
    ys = torch.zeros(1, 1).type_as(src.data)

    for i in range(9):
        out = test_model.decode(memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data))
        prob = test_model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        ys = torch.cat([ys, torch.empty(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    
    print("Example Untrained Model Prediction:", ys)

In [184]:
# for _ in range (10):
#     inference_test()

# Training

### Data preparation

In [None]:
from datasets import load_dataset

# billsum = load_dataset("billsum", split="ca_test")
# NOTE: using bigger training datasets.
billsum = load_dataset("billsum", split="train")

In [201]:
# Split data to train(0.8) & test(0.2)
billsum_split = billsum.train_test_split(test_size=0.2)
print('sample', billsum_split["train"][0])
print('training dataset size:', len(billsum_split['train']))
print('testing dataset size:', len(billsum_split['test']))

sample {'text': "SECTION 1. FAIRNESS AND ACCURACY IN HIGH STAKES EDUCATIONAL DECISIONS \n              FOR STUDENTS.\n\n    (a) Findings.--Congress makes the following findings:\n            (1) The use of large-scale achievement tests in education \n        has grown significantly in recent years. States and local \n        school districts have increasingly used these tests in such \n        contexts as raising student academic standards to make high-\n        stakes decisions with important consequences for individual \n        students, such as tracking (assigning students to schools, \n        programs, or classes based on achievement level), promotion of \n        students to the next grade, and graduation of students from \n        secondary school.\n            (2) The serious and often adverse consequences resulting \n        from the sole reliance on large-scale tests have increasingly \n        resulted in questions and significant concerns by students, \n        parents, te

### data tokenization

In [187]:
from transformers import AutoTokenizer

checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

prefix = "summarize: "

def preprocess(examples):
    inputs = [prefix + doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=256, truncation=True)
    labels = tokenizer(text_target=examples["summary"], max_length=64, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


In [188]:
tokenized_billsum = billsum_split.map(preprocess, batched=True)

Map: 100%|██████████| 989/989 [00:02<00:00, 444.09 examples/s]
Map: 100%|██████████| 248/248 [00:00<00:00, 431.56 examples/s]


In [189]:
print('', tokenized_billsum.shape)
print('', tokenized_billsum['test'][0].keys())

 {'train': (989, 6), 'test': (248, 6)}
 dict_keys(['text', 'summary', 'title', 'input_ids', 'attention_mask', 'labels'])


In [190]:
print('', tokenized_billsum['test'][0]['input_ids'])
print('', tokenized_billsum['test'][0]['labels'])

 [21603, 10, 37, 151, 13, 8, 1015, 13, 1826, 103, 3, 35, 2708, 38, 6963, 10, 180, 3073, 9562, 1300, 5568, 460, 13606, 13, 8, 9836, 3636, 19, 21012, 12, 608, 10, 460, 13606, 5, 41, 9, 61, 242, 3659, 13, 48, 1375, 6, 8, 826, 4903, 7, 1581, 10, 5637, 105, 15291, 127, 1208, 364, 6152, 153, 598, 46, 5936, 53, 1745, 24, 8201, 28, 820, 42, 722, 12, 1899, 2765, 12, 1912, 364, 21, 8, 820, 42, 722, 11, 24, 1912, 7, 66, 13, 8, 826, 3621, 10, 41, 188, 61, 31663, 6203, 28, 820, 11, 722, 21, 4573, 224, 38, 8, 97, 11, 286, 213, 8, 364, 33, 12, 36, 937, 6, 8, 686, 13, 161, 6, 8, 464, 1124, 6, 11, 8, 463, 11, 594, 13, 8, 364, 5, 41, 279, 61, 30197, 7, 14023, 42, 3, 864, 7, 6732, 4128, 13, 2765, 6, 237, 3, 99, 2765, 7365, 8, 269, 12, 9460, 806, 14023, 5, 41, 254, 61, 419, 17, 13676, 8, 5015, 12, 12317, 42, 3, 864, 7, 6732, 3, 9, 10416, 12, 430, 1188, 42, 884, 116, 8, 10416, 19, 4187, 29452, 57, 3, 9, 806, 1188, 42, 884, 5, 41, 308, 61, 282, 6732, 7, 42, 3, 864, 7, 6732, 7, 2765, 12, 1912, 364, 21, 820, 

In [191]:
src_vocab = [item['text'].split() for item in tokenized_billsum['train']]
flattened_src = set([item for sublist in src_vocab for item in sublist])
print(len(flattened_src))

target_vocab = [item['text'].split() for item in tokenized_billsum['test']]
flattened_target = set([item for sublist in target_vocab for item in sublist])
print(len(flattened_target))


49246
22364


In [192]:
tokens = tokenizer.convert_ids_to_tokens(tokenized_billsum['train'][0]['input_ids'])
# tokens = tokenizer.convert_ids_to_tokens([0])
# Join tokens without any separator
joined_text = ''.join(tokens)
# Replace the marker with a space and strip any leading/trailing spaces
final_text = joined_text.replace('▁', ' ').strip()
print(final_text)

summarize: The people of the State of California do enact as follows: SECTION 1. Section 21107.8 of the Vehicle Code is amended to read: 21107.8. (a) (1) Any city or county may, by ordinance or resolution, find and declare that there are privately owned and maintained offstreet parking facilities as described in the ordinance or resolution within the city or county that are generally held open for use of the public for purposes of vehicular parking. Upon enactment by a city or county of the ordinance or resolution, Sections 22350, 23103, and 23109 and the provisions of Division 16.5 (commencing with Section 38000) shall apply to privately owned and maintained offstreet parking facilities, except as provided in subdivision (b). (2) (A) If a city or county enacts an ordinance or resolution authorized by paragraph (1), a city or county may include in that ordinance or resolution authorization for the operator of a privately owned and maintained offstreet parking facility to regulate unaut

In [193]:
# Batching the data
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence

input_tokens_src = [input_ids['input_ids'] for input_ids in tokenized_billsum['train']]
input_tokens_target = [label['labels'] for label in tokenized_billsum['train']]
print(len(input_tokens_src))
print(len(input_tokens_target))

# Convert each list into a tensor
src_tensors = [torch.tensor(seq, dtype=torch.long) for seq in input_tokens_src]
target_tensors = [torch.tensor(seq, dtype=torch.long) for seq in input_tokens_target]

# Pad the sequences (e.g., using padding_value=0, and setting batch_first=True)
padded_src = pad_sequence(src_tensors, batch_first=True, padding_value=0)
padded_target = pad_sequence(target_tensors, batch_first=True, padding_value=0)

# Create a TensorDataset using the padded tensors
train_dataset = TensorDataset(padded_src, padded_target)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
 

989
989


### Dummy inputs

In [194]:
# Vocabulary & Tokenizer
dummy_vocab = {
    "<pad>": 0,
    "<bos>": 1,
    "<eos>": 2,
    "the": 3,
    "cat": 4,
    "sat": 5,
    "on": 6,
    "mat": 7,
    "summary": 8,
    ".": 9
}

# tokenizer reverse
dummy_id2word = {v: k for k, v in dummy_vocab.items()}
print(dummy_id2word)

# input and expected summary
dummy_input_text = "the cat sat on the mat ."
dummy_summary_text = "cat on mat ."

# Tokenized
dummy_input_ids  = torch.tensor([1, 3, 4, 5, 6, 3, 7, 9, 2])  # <bos> the cat sat on the mat . <eos>
dummy_summary_ids = torch.tensor([1, 4, 6, 7, 9, 2])          # <bos> cat on mat . <eos>

# Define masks
dummy_src_mask = (dummy_input_ids != 0).unsqueeze(-2)

# decoder receives:
# At each decoding step i, it tries to predict tgt_output[i] from previous tgt_input[:i+1].
# i =     0,       1,         2
# in:    [1]    [1, 4]    [1, 4, 6]
# out:   [4]      [6]        [7]
dummy_tgt_input = [1, 4, 6, 7, 9]     # <bos> cat on mat .
dummy_tgt_output = [4, 6, 7, 9, 2]    # cat on mat . <eos>

{0: '<pad>', 1: '<bos>', 2: '<eos>', 3: 'the', 4: 'cat', 5: 'sat', 6: 'on', 7: 'mat', 8: 'summary', 9: '.'}


In each step:
•	Embeddings convert token IDs to vectors
•	Self-attention looks at previous tokens
•	Decoder cross-attends to encoder output (from the full input sentence)
•	Output logits are projected to vocab size and softmax gives probabilities
•	You take the argmax to get the next token

During decoding, the decoder looks at the encoder’s outputs (which represent the input sentence) using cross-attention.
This is how the decoder knows what the input was, so it can generate a relevant output.

### Training the system

In [195]:
def make_src_mask(src, pad_token=0):
    """
    src: Tensor of shape (batch_size, src_len)
    Returns: Tensor of shape (batch_size, 1, 1, src_len)
    
    src != pad_token
	•	This compares every token ID in the input tensor src to the padding token ID pad_token.
	•	Returns a boolean tensor of the same shape as src, where:
        •	True means the token is not padding (i.e. real input),
        •	False means the token is padding.
 
    •   Adds a new dimension at position -2 (i.e., the second-to-last axis).
	•	This reshapes the mask to fit what the attention mechanism expects.
    """
    src_mask = (src != pad_token).unsqueeze(-2)
    return src_mask  # shape: (batch_size, 1, 1, src_len)


def make_tgt_mask(tgt, pad_token=0):
    """
    tgt: Tensor of shape (batch_size, tgt_len)
    Returns: Tensor of shape (batch_size, 1, tgt_len, tgt_len)
    
    torch.tril(...)
	•	Applies a lower triangular mask:
	•	Keeps the values on the diagonal and below it.
	•	Sets everything above the diagonal to zero.

    Example (if tgt_len = 4):
    torch.tril(torch.ones(4, 4)) ➝ tensor([[1, 0, 0, 0],
                                            [1, 1, 0, 0],
                                            [1, 1, 1, 0],
                                            [1, 1, 1, 1]])
    """
    batch_size, tgt_len = tgt.size()
    # Padding mask
    tgt_pad_mask = (tgt != pad_token).unsqueeze(-2)  # (batch_size, 1, tgt_len)
    # subsequent mask used in the decoder to prevent “cheating” during training 
    # (i.e., to ensure that each token can only attend to itself and previous tokens, not future ones).
    # .bool() converts the matrix to boolean values: 1 -> True (can attend) 0 -> False (cannot attend)
    tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()  # (tgt_len, tgt_len)
    # To Combine masks their shapes must be broadcastable.
    # 	•	tgt_pad_mask: [batch_size, 1, tgt_len]
    # 	•	tgt_sub_mask: [tgt_len, tgt_len]
    # 	•	When combined with &, PyTorch broadcasts tgt_pad_mask to [batch_size, tgt_len, tgt_len]
    # Now, the final tgt_mask says: For each query token, which keys can it attend to? Only if:
    # 	•	the key token is not padding (tgt_pad_mask == True)
    # 	•	and it’s not in the future (tgt_sub_mask == True)
    tgt_mask = tgt_pad_mask & tgt_sub_mask  # (batch_size, 1, tgt_len, tgt_len)
    return tgt_mask

In [None]:
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# src_vocab_size = len(flattened_src)
# target_vocab_size = len(flattened_target)
# NOTE: update: Use the tokenizer vocab instead
src_vocab_size = tokenizer.vocab_size
target_vocab_size = tokenizer.vocab_size
# NOTE: update: increase the model layer
model = make_model(src_vocab_size, target_vocab_size, 4).to("mps" if torch.backends.mps.is_available() else "cpu")

# Loss & Optimizer
loss_fn = nn.NLLLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
epochs = 100

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    model.train()
    total_loss = 0

    for src, tgt in tqdm(train_loader):
        # print('src', src.shape)
        src_mask = make_src_mask(src, pad_token=0)
        
        # Shift target for decoder input/output
        # [nlp final project]
        # [nlp final]
        # [final project]
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]
        tgt_len = tgt_input.size(1)
        
        # Create padding mask: (batch_size, 1, 1, tgt_len)
        pad_mask = (tgt_input != 0).unsqueeze(-2) # pad_token = 0
        
        # Create look-ahead mask: (1, tgt_len, tgt_len)
        look_ahead_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()

        # Combine: (batch_size, 1, tgt_len, tgt_len)
        tgt_mask = pad_mask & look_ahead_mask.unsqueeze(0)

        # Move all to device
        src = src.to(device)
        tgt_input = tgt_input.to(device)
        tgt_output = tgt_output.to(device)
        src_mask = src_mask.to(device)
        tgt_mask = tgt_mask.to(device)
        
        if PRINT:
            print("src:", src.shape)
            print("tgt_input:", tgt_input.shape)
            print("src_mask:", src_mask.shape)
            print("tgt_mask:", tgt_mask.shape)

        assert tgt_mask.shape[-1] == tgt_input.shape[1], "Target mask length doesn't match input"
        assert src_mask.shape[-1] == src.shape[1], "Source mask length doesn't match input"
        
        out = model(src, tgt_input, src_mask, tgt_mask)
        if PRINT:
            print(f'EncoderDecoder output shape: {out.shape}')
        
        logits = model.generator(out)
        if PRINT:
            print(f'Transformer output shape: {logits.shape}')

        logits = logits.view(-1, logits.size(-1))
        tgt_output = tgt_output.contiguous().view(-1)

        loss = loss_fn(logits, tgt_output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        break

    avg_loss = total_loss / len(train_loader)
    print(f"Average loss: {avg_loss:.4f}")
    break



Epoch 1/100


  0%|          | 0/31 [00:00<?, ?it/s]

src: torch.Size([32, 256])
tgt_input: torch.Size([32, 63])
src_mask: torch.Size([32, 1, 256])
tgt_mask: torch.Size([32, 63, 63])
in EncoderDecoder before embedding(Encoder): src: torch.Size([32, 256]); src_mask: torch.Size([32, 1, 256])
in Encoder after embedding: src:torch.Size([32, 256, 512]); src_mask: torch.Size([32, 1, 256])
In MHA:
Layer 0: in_features = 512, out_features = 512
Layer 1: in_features = 512, out_features = 512
Layer 2: in_features = 512, out_features = 512
Layer 3: in_features = 512, out_features = 512
MHA return val: torch.Size([32, 256, 512])
in FF: input shape: torch.Size([32, 256, 512])
in FF: output shape: torch.Size([32, 256, 512])
in EncoderDecoder before embedding(Decoder): tgt: torch.Size([32, 63]); tgt_mask: torch.Size([32, 63, 63])
in Decoder after embedding: tgt: torch.Size([32, 63, 512]); encoder_output:torch.Size([32, 256, 512]); encoder_output_mask: torch.Size([32, 1, 256]); tgt_mask:torch.Size([32, 63, 63])
In MHA:
Layer 0: in_features = 512, out_fea

  0%|          | 0/31 [00:00<?, ?it/s]

Average loss: 0.3175





In [None]:
def greedy_decode(model, src, src_mask, max_len, start_symbol, tokenizer):
    # Encoding source input
    memory = model.encode(src, src_mask)

    # Initializing decoder input with start token
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src).to(src.device)

    for i in range(max_len - 1):
        tgt_mask = subsequent_mask(ys.size(1)).to(src.device).unsqueeze(1)

        # Decoding one step
        out = model.decode(memory, src_mask, ys, tgt_mask)

        # Getting log-probabilities from generator
        prob = model.generator(out[:, -1])
        next_word = torch.argmax(prob, dim=-1).unsqueeze(1)

        # Appending the predicted word to the decoder input
        ys = torch.cat([ys, next_word], dim=1)

        # Optional: stop if EOS token is predicted
        if next_word.item() == tokenizer.eos_token_id:
            break

    return ys


def generate_summary(model, input_text, tokenizer):
    model.eval()

    # Tokenize input abstract
    encoding = tokenizer("summarize: " + input_text,
                         return_tensors="pt",
                         max_length=512,
                         truncation=True,
                         padding="max_length")

    src = encoding["input_ids"].to(next(model.parameters()).device)
    src_mask = encoding["attention_mask"].unsqueeze(0).unsqueeze(0).to(src.device)

    # Decode
    # start_symbol = tokenizer.convert_tokens_to_ids("summarize")  # or use tokenizer.bos_token_id if available
    # NOTE: start with valid beginning
    start_symbol = tokenizer.pad_token_id if tokenizer.bos_token_id is None else tokenizer.bos_token_id
    decoded_ids = greedy_decode(model, src, src_mask, max_len=64, start_symbol=start_symbol, tokenizer=tokenizer)

    # Convert token IDs to text
    return tokenizer.decode(decoded_ids[0], skip_special_tokens=True)

In [198]:
# sample = 'Batching matters a ton for speed. We want to have very evenly divided batches, with absolutely minimal padding. To do this we have to hack a bit around the default torchtext batching. This code patches their default batching to make sure we search over enough sentences to find tight batches.'
# generated_title = generate_summary(model, sample, tokenizer)
# print("Generated Title:", generated_title)

### Save the model

In [None]:
# saving the entire model
torch.save(model, 'my_transformer_model.pth')

# saving the weights only
torch.save(model.state_dict(), 'model_weights.pth')

### Call saved model

In [None]:
input_text = 'Batching matters a ton for speed. We want to have very evenly divided batches, with absolutely minimal padding. To do this we have to hack a bit around the default torchtext batching. This code patches their default batching to make sure we search over enough sentences to find tight batches.'

In [None]:
# loading the model:
model_saved = torch.load('my_transformer_model.pth', weights_only=False)
model_saved.eval()  # Set to eval mode if using for inference
generated_summary = generate_summary(model_saved, input_text, tokenizer)
print("Generated summary:", generated_summary)

In [None]:
# loading the weights:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") 
src_vocab_size = tokenizer.vocab_size
target_vocab_size = tokenizer.vocab_size
model_saved = make_model(src_vocab_size, target_vocab_size, N=6, d_model=768, d_ff=3072, h=8, dropout=0.15).to("mps" if torch.backends.mps.is_available() else "cpu")
model_saved.load_state_dict(torch.load("model_weights_v3.pth", map_location=device))
model_saved.to(device)
model_saved.eval()
generated_summary = generate_summary(model_saved, input_text, tokenizer)
print("Generated summary:", generated_summary)